Implement GK sketching on GPU. (#5846)

* Implement GK sketching on GPU.
* Strong tests on quantile building.
* Handle sparse dataset by binary searching the column index.
* Hypothesis test on dask.
This commit is contained in:
Jiaming Yuan 2020-07-07 12:16:21 +08:00 committed by GitHub
parent ac3f0e78dc
commit 048d969be4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 2045 additions and 405 deletions

1
Jenkinsfile vendored
View File

@ -313,6 +313,7 @@ def TestPythonGPU(args) {
nodeReq = (args.multi_gpu) ? 'linux && mgpu' : 'linux && gpu'
node(nodeReq) {
unstash name: 'xgboost_whl_cuda10'
unstash name: 'xgboost_cpp_tests'
unstash name: 'srcs'
echo "Test Python GPU: CUDA ${args.cuda_version}"
def container_type = "gpu"

View File

@ -573,8 +573,8 @@ class Span {
XGBOOST_DEVICE auto subspan() const -> // NOLINT
Span<element_type,
detail::ExtentValue<Extent, Offset, Count>::value> {
SPAN_CHECK(Offset < size() || size() == 0);
SPAN_CHECK(Count == dynamic_extent || (Offset + Count <= size()));
SPAN_CHECK((Count == dynamic_extent) ?
(Offset <= size()) : (Offset + Count <= size()));
return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count};
}
@ -582,9 +582,8 @@ class Span {
XGBOOST_DEVICE Span<element_type, dynamic_extent> subspan( // NOLINT
index_type _offset,
index_type _count = dynamic_extent) const {
SPAN_CHECK(_offset < size() || size() == 0);
SPAN_CHECK((_count == dynamic_extent) || (_offset + _count <= size()));
SPAN_CHECK((_count == dynamic_extent) ?
(_offset <= size()) : (_offset + _count <= size()));
return {data() + _offset, _count ==
dynamic_extent ? size() - _offset : _count};
}

View File

@ -78,6 +78,33 @@ void AllReducer::Init(int _device_ordinal) {
#endif // XGBOOST_USE_NCCL
}
void AllReducer::AllGather(void const *data, size_t length_bytes,
std::vector<size_t> *segments,
dh::caching_device_vector<char> *recvbuf) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
size_t world = rabit::GetWorldSize();
segments->clear();
segments->resize(world, 0);
segments->at(rabit::GetRank()) = length_bytes;
rabit::Allreduce<rabit::op::Max>(segments->data(), segments->size());
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0);
recvbuf->resize(total_bytes);
size_t offset = 0;
safe_nccl(ncclGroupStart());
for (int32_t i = 0; i < world; ++i) {
size_t as_bytes = segments->at(i);
safe_nccl(
ncclBroadcast(data, recvbuf->data().get() + offset,
as_bytes, ncclChar, i, comm_, stream_));
offset += as_bytes;
}
safe_nccl(ncclGroupEnd());
#endif // XGBOOST_USE_NCCL
}
AllReducer::~AllReducer() {
#ifdef XGBOOST_USE_NCCL
if (initialised_) {

View File

@ -5,10 +5,16 @@
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/device_malloc_allocator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#include <thrust/execution_policy.h>
#include <thrust/transform_scan.h>
#include <thrust/logical.h>
#include <thrust/gather.h>
#include <thrust/unique.h>
#include <thrust/binary_search.h>
#include <rabit/rabit.h>
@ -53,6 +59,36 @@ __device__ __forceinline__ double atomicAdd(double* address, double val) { // N
}
#endif
namespace dh {
namespace detail {
template <size_t size>
struct AtomicDispatcher;
template <>
struct AtomicDispatcher<sizeof(uint32_t)> {
using Type = unsigned int; // NOLINT
static_assert(sizeof(Type) == sizeof(uint32_t), "Unsigned should be of size 32 bits.");
};
template <>
struct AtomicDispatcher<sizeof(uint64_t)> {
using Type = unsigned long long; // NOLINT
static_assert(sizeof(Type) == sizeof(uint64_t), "Unsigned long long should be of size 64 bits.");
};
} // namespace detail
} // namespace dh
// atomicAdd is not defined for size_t.
template <typename T = size_t,
std::enable_if_t<std::is_same<size_t, T>::value &&
!std::is_same<size_t, unsigned long long>::value> * = // NOLINT
nullptr>
T __device__ __forceinline__ atomicAdd(T *addr, T v) { // NOLINT
using Type = typename dh::detail::AtomicDispatcher<sizeof(T)>::Type;
Type ret = ::atomicAdd(reinterpret_cast<Type *>(addr), static_cast<Type>(v));
return static_cast<T>(ret);
}
namespace dh {
#define HOST_DEV_INLINE XGBOOST_DEVICE __forceinline__
@ -291,10 +327,12 @@ public:
safe_cuda(cudaGetDevice(&current_device));
stats_.RegisterDeallocation(ptr, n, current_device);
}
size_t PeakMemory()
{
size_t PeakMemory() const {
return stats_.peak_allocated_bytes;
}
size_t CurrentlyAllocatedBytes() const {
return stats_.currently_allocated_bytes;
}
void Clear()
{
stats_ = DeviceStats();
@ -529,7 +567,6 @@ class AllReducer {
bool initialised_ {false};
size_t allreduce_bytes_ {0}; // Keep statistics of the number of bytes communicated
size_t allreduce_calls_ {0}; // Keep statistics of the number of reduce calls
std::vector<size_t> host_data_; // Used for all reduce on host
#ifdef XGBOOST_USE_NCCL
ncclComm_t comm_;
cudaStream_t stream_;
@ -569,6 +606,27 @@ class AllReducer {
#endif
}
/**
* \brief Allgather implemented as grouped calls to Broadcast. This way we can accept
* different size of data on different workers.
* \param length_bytes Size of input data in bytes.
* \param segments Size of data on each worker.
* \param recvbuf Buffer storing the result of data from all workers.
*/
void AllGather(void const* data, size_t length_bytes,
std::vector<size_t>* segments, dh::caching_device_vector<char>* recvbuf);
void AllGather(uint32_t const* data, size_t length,
dh::caching_device_vector<uint32_t>* recvbuf) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
size_t world = rabit::GetWorldSize();
recvbuf->resize(length * world);
safe_nccl(ncclAllGather(data, recvbuf->data().get(), length, ncclUint32,
comm_, stream_));
#endif // XGBOOST_USE_NCCL
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
@ -607,6 +665,40 @@ class AllReducer {
#endif
}
void AllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint32, ncclSum, comm_, stream_));
#endif
}
void AllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_));
#endif
}
// Specialization for size_t, which is implementation defined so it might or might not
// be one of uint64_t/uint32_t/unsigned long long/unsigned long.
template <typename T = size_t,
std::enable_if_t<std::is_same<size_t, T>::value &&
!std::is_same<size_t, unsigned long long>::value> // NOLINT
* = nullptr>
void AllReduceSum(const T *sendbuff, T *recvbuff, int count) { // NOLINT
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
static_assert(sizeof(unsigned long long) == sizeof(uint64_t), ""); // NOLINT
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_));
#endif
}
/**
* \fn void Synchronize()
*
@ -886,9 +978,86 @@ DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
// Thrust version of this function causes error on Windows
template <typename ReturnT, typename IterT, typename FuncT>
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
XGBOOST_DEVICE thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
IterT iter, FuncT func) {
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
}
template <typename It>
size_t XGBOOST_DEVICE SegmentId(It first, It last, size_t idx) {
size_t segment_id = thrust::upper_bound(thrust::seq, first, last, idx) -
1 - first;
return segment_id;
}
template <typename T>
size_t XGBOOST_DEVICE SegmentId(xgboost::common::Span<T> segments_ptr, size_t idx) {
return SegmentId(segments_ptr.cbegin(), segments_ptr.cend(), idx);
}
namespace detail {
template <typename Key, typename KeyOutIt>
struct SegmentedUniqueReduceOp {
KeyOutIt key_out;
__device__ Key const& operator()(Key const& key) const {
auto constexpr kOne = static_cast<std::remove_reference_t<decltype(*(key_out + key.first))>>(1);
atomicAdd(&(*(key_out + key.first)), kOne);
return key;
}
};
} // namespace detail
/* \brief Segmented unique function. Keys are pointers to segments with key_segments_last -
* key_segments_first = n_segments + 1.
*
* \pre Input segment and output segment must not overlap.
*
* \param key_segments_first Beginning iterator of segments.
* \param key_segments_last End iterator of segments.
* \param val_first Beginning iterator of values.
* \param val_last End iterator of values.
* \param key_segments_out Output iterator of segments.
* \param val_out Output iterator of values.
*
* \return Number of unique values in total.
*/
template <typename KeyInIt, typename KeyOutIt, typename ValInIt,
typename ValOutIt, typename Comp>
size_t
SegmentedUnique(KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first,
ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out,
Comp comp) {
using Key = thrust::pair<size_t, typename thrust::iterator_traits<ValInIt>::value_type>;
dh::XGBCachingDeviceAllocator<char> alloc;
auto unique_key_it = dh::MakeTransformIterator<Key>(
thrust::make_counting_iterator(static_cast<size_t>(0)),
[=] __device__(size_t i) {
size_t seg = dh::SegmentId(key_segments_first, key_segments_last, i);
return thrust::make_pair(seg, *(val_first + i));
});
size_t segments_len = key_segments_last - key_segments_first;
thrust::fill(thrust::device, key_segments_out, key_segments_out + segments_len, 0);
size_t n_inputs = std::distance(val_first, val_last);
// Reduce the number of uniques elements per segment, avoid creating an intermediate
// array for `reduce_by_key`. It's limited by the types that atomicAdd supports. For
// example, size_t is not supported as of CUDA 10.2.
auto reduce_it = thrust::make_transform_output_iterator(
thrust::make_discard_iterator(),
detail::SegmentedUniqueReduceOp<Key, KeyOutIt>{key_segments_out});
auto uniques_ret = thrust::unique_by_key_copy(
thrust::cuda::par(alloc), unique_key_it, unique_key_it + n_inputs,
val_first, reduce_it, val_out,
[=] __device__(Key const &l, Key const &r) {
if (l.first == r.first) {
// In the same segment.
return comp(l.second, r.second);
}
return false;
});
auto n_uniques = uniques_ret.second - val_out;
CHECK_LE(n_uniques, n_inputs);
thrust::exclusive_scan(thrust::cuda::par(alloc), key_segments_out,
key_segments_out + segments_len, key_segments_out, 0);
return n_uniques;
}
} // namespace dh

View File

@ -158,7 +158,6 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
uint32_t beg_col, uint32_t end_col,
uint32_t thread_id) {
CHECK_GE(end_col, beg_col);
constexpr float kFactor = 8;
// Data groups, used in ranking.
std::vector<bst_uint> const& group_ptr = info.group_ptr_;
@ -175,11 +174,12 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
max_num_bins);
if (n_bins == 0) {
// cut_ptrs_ is initialized with a zero, so there's always an element at the back
CHECK_GE(local_ptrs.size(), 1);
local_ptrs.emplace_back(local_ptrs.back());
continue;
}
sketch.Init(info.num_row_, 1.0 / (n_bins * kFactor));
sketch.Init(info.num_row_, 1.0 / (n_bins * WQSketch::kFactor));
for (auto const& entry : column) {
uint32_t weight_ind = 0;
if (use_group_ind) {
@ -329,7 +329,6 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {
const MetaInfo& info = p_fmat->Info();
// safe factor for better accuracy
constexpr int kFactor = 8;
std::vector<WQSketch> sketchs;
const int nthread = omp_get_max_threads();
@ -339,7 +338,7 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {
unsigned const ncol = static_cast<unsigned>(info.num_col_);
sketchs.resize(info.num_col_);
for (auto& s : sketchs) {
s.Init(info.num_row_, 1.0 / (max_num_bins * kFactor));
s.Init(info.num_row_, 1.0 / (max_num_bins * WQSketch::kFactor));
}
// Data groups, used in ranking.
@ -410,9 +409,8 @@ void DenseCuts::Init
// This allows efficient training on wide data
size_t global_max_rows = max_rows;
rabit::Allreduce<rabit::op::Sum>(&global_max_rows, 1);
constexpr int kFactor = 8;
size_t intermediate_num_cuts =
std::min(global_max_rows, static_cast<size_t>(max_num_bins * kFactor));
std::min(global_max_rows, static_cast<size_t>(max_num_bins * WQSketch::kFactor));
// gather the histogram data
rabit::SerializeReducer<WQSketch::SummaryContainer> sreducer;
std::vector<WQSketch::SummaryContainer> summary_array;

View File

@ -8,6 +8,7 @@
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/reduce.h>
#include <thrust/sort.h>
#include <thrust/binary_search.h>
@ -31,21 +32,20 @@ namespace common {
constexpr float SketchContainer::kFactor;
namespace detail {
// Count the entries in each column and exclusive scan
void ExtractCuts(int device,
size_t num_cuts_per_feature,
Span<Entry const> sorted_data,
Span<size_t const> column_sizes_scan,
Span<SketchEntry> out_cuts) {
void ExtractCutsSparse(int device, common::Span<SketchContainer::OffsetT const> cuts_ptr,
Span<Entry const> sorted_data,
Span<size_t const> column_sizes_scan,
Span<SketchEntry> out_cuts) {
dh::LaunchN(device, out_cuts.size(), [=] __device__(size_t idx) {
// Each thread is responsible for obtaining one cut from the sorted input
size_t column_idx = idx / num_cuts_per_feature;
size_t column_idx = dh::SegmentId(cuts_ptr, idx);
size_t column_size =
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
size_t num_available_cuts =
min(static_cast<size_t>(num_cuts_per_feature), column_size);
size_t cut_idx = idx % num_cuts_per_feature;
if (cut_idx >= num_available_cuts) return;
size_t num_available_cuts = cuts_ptr[column_idx + 1] - cuts_ptr[column_idx];
size_t cut_idx = idx - cuts_ptr[column_idx];
Span<Entry const> column_entries =
sorted_data.subspan(column_sizes_scan[column_idx], column_size);
size_t rank = (column_entries.size() * cut_idx) /
@ -55,31 +55,20 @@ void ExtractCuts(int device,
});
}
/**
* \brief Extracts the cuts from sorted data, considering weights.
*
* \param device The device.
* \param cuts Output cuts.
* \param num_cuts_per_feature Number of cuts per feature.
* \param sorted_data Sorted entries in segments of columns.
* \param weights_scan Inclusive scan of weights for each entry in sorted_data.
* \param column_sizes_scan Describes the boundaries of column segments in sorted data.
*/
void ExtractWeightedCuts(int device,
size_t num_cuts_per_feature,
Span<Entry> sorted_data,
Span<float> weights_scan,
Span<size_t> column_sizes_scan,
Span<SketchEntry> cuts) {
void ExtractWeightedCutsSparse(int device,
common::Span<SketchContainer::OffsetT const> cuts_ptr,
Span<Entry> sorted_data,
Span<float> weights_scan,
Span<size_t> column_sizes_scan,
Span<SketchEntry> cuts) {
dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) {
// Each thread is responsible for obtaining one cut from the sorted input
size_t column_idx = idx / num_cuts_per_feature;
size_t column_idx = dh::SegmentId(cuts_ptr, idx);
size_t column_size =
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
size_t num_available_cuts =
min(static_cast<size_t>(num_cuts_per_feature), column_size);
size_t cut_idx = idx % num_cuts_per_feature;
if (cut_idx >= num_available_cuts) return;
size_t num_available_cuts = cuts_ptr[column_idx + 1] - cuts_ptr[column_idx];
size_t cut_idx = idx - cuts_ptr[column_idx];
Span<Entry> column_entries =
sorted_data.subspan(column_sizes_scan[column_idx], column_size);
@ -109,7 +98,7 @@ void ExtractWeightedCuts(int device,
max(static_cast<size_t>(0),
min(sample_idx, column_entries.size() - 1));
}
// repeated values will be filtered out on the CPU
// repeated values will be filtered out later.
bst_float rmin = sample_idx > 0 ? column_weights_scan[sample_idx - 1] : 0.0f;
bst_float rmax = column_weights_scan[sample_idx];
cuts[idx] = WQSketch::Entry(rmin, rmax, rmax - rmin,
@ -117,31 +106,71 @@ void ExtractWeightedCuts(int device,
});
}
void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end,
SketchContainer* sketch_container, int num_cuts,
size_t num_columns) {
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
host_data.begin() + end);
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), EntryCompareOp());
size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows) {
double eps = 1.0 / (WQSketch::kFactor * max_bins);
size_t dummy_nlevel;
size_t num_cuts;
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
num_rows, eps, &dummy_nlevel, &num_cuts);
return std::min(num_cuts, num_rows);
}
dh::caching_device_vector<size_t> column_sizes_scan;
GetColumnSizesScan(device, &column_sizes_scan,
{sorted_entries.data().get(), sorted_entries.size()},
num_columns);
thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);
size_t RequiredSampleCuts(bst_row_t num_rows, bst_feature_t num_columns,
size_t max_bins, size_t nnz) {
auto per_column = RequiredSampleCutsPerColumn(max_bins, num_rows);
auto if_dense = num_columns * per_column;
auto result = std::min(nnz, if_dense);
return result;
}
dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts);
ExtractCuts(device, num_cuts,
dh::ToSpan(sorted_entries),
dh::ToSpan(column_sizes_scan),
dh::ToSpan(cuts));
size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz,
size_t num_bins, bool with_weights) {
size_t peak = 0;
// 0. Allocate cut pointer in quantile container by increasing: n_columns + 1
size_t total = (num_columns + 1) * sizeof(SketchContainer::OffsetT);
// 1. Copy and sort: 2 * bytes_per_element * shape
total += BytesPerElement(with_weights) * num_rows * num_columns;
peak = std::max(peak, total);
// 2. Deallocate bytes_per_element * shape due to reusing memory in sort.
total -= BytesPerElement(with_weights) * num_rows * num_columns / 2;
// 3. Allocate colomn size scan by increasing: n_columns + 1
total += (num_columns + 1) * sizeof(SketchContainer::OffsetT);
// 4. Allocate cut pointer by increasing: n_columns + 1
total += (num_columns + 1) * sizeof(SketchContainer::OffsetT);
// 5. Allocate cuts: assuming rows is greater than bins: n_columns * limit_size
total += RequiredSampleCuts(num_rows, num_bins, num_bins, nnz) * sizeof(SketchEntry);
// 6. Deallocate copied entries by reducing: bytes_per_element * shape.
peak = std::max(peak, total);
total -= (BytesPerElement(with_weights) * num_rows * num_columns) / 2;
// 7. Deallocate column size scan.
peak = std::max(peak, total);
total -= (num_columns + 1) * sizeof(SketchContainer::OffsetT);
// 8. Deallocate cut size scan.
total -= (num_columns + 1) * sizeof(SketchContainer::OffsetT);
// 9. Allocate final cut values, min values, cut ptrs: std::min(rows, bins + 1) *
// n_columns + n_columns + n_columns + 1
total += std::min(num_rows, num_bins) * num_columns * sizeof(float);
total += num_columns *
sizeof(std::remove_reference_t<decltype(
std::declval<HistogramCuts>().MinValues())>::value_type);
total += (num_columns + 1) *
sizeof(std::remove_reference_t<decltype(
std::declval<HistogramCuts>().Ptrs())>::value_type);
peak = std::max(peak, total);
// add cuts into sketches
thrust::host_vector<SketchEntry> host_cuts(cuts);
sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
return peak;
}
size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
bst_row_t num_rows, size_t columns, size_t nnz, int device,
size_t num_cuts, bool has_weight) {
if (sketch_batch_num_elements == 0) {
auto required_memory = RequiredMemory(num_rows, columns, nnz, num_cuts, has_weight);
// use up to 80% of available space
sketch_batch_num_elements = (dh::AvailableMemory(device) -
required_memory * 0.8);
}
return sketch_batch_num_elements;
}
void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
@ -150,7 +179,7 @@ void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
// Sort both entries and wegihts.
thrust::sort_by_key(thrust::cuda::par(*alloc), sorted_entries->begin(),
sorted_entries->end(), weights->begin(),
EntryCompareOp());
detail::EntryCompareOp());
// Scan weights
thrust::inclusive_scan_by_key(thrust::cuda::par(*alloc),
@ -160,6 +189,46 @@ void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
return a.index == b.index;
});
}
} // namespace detail
void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end,
SketchContainer *sketch_container, int num_cuts_per_feature,
size_t num_columns) {
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
host_data.begin() + end);
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), detail::EntryCompareOp());
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
dh::caching_device_vector<size_t> column_sizes_scan;
data::IsValidFunctor dummy_is_valid(std::numeric_limits<float>::quiet_NaN());
auto batch_it = dh::MakeTransformIterator<data::COOTuple>(
sorted_entries.data().get(),
[] __device__(Entry const &e) -> data::COOTuple {
return {0, e.index, e.fvalue}; // row_idx is not needed for scanning column size.
});
detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature,
batch_it, dummy_is_valid,
0, sorted_entries.size(),
&cuts_ptr, &column_sizes_scan);
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
detail::ExtractCutsSparse(device, d_cuts_ptr, dh::ToSpan(sorted_entries),
dh::ToSpan(column_sizes_scan), dh::ToSpan(cuts));
// add cuts into sketches
sorted_entries.clear();
sorted_entries.shrink_to_fit();
CHECK_EQ(sorted_entries.capacity(), 0);
CHECK_NE(cuts_ptr.Size(), 0);
sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
}
void ProcessWeightedBatch(int device, const SparsePage& page,
Span<const float> weights, size_t begin, size_t end,
@ -204,40 +273,53 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
d_temp_weights[idx] = weights[ridx + base_rowid];
});
}
SortByWeight(&alloc, &temp_weights, &sorted_entries);
detail::SortByWeight(&alloc, &temp_weights, &sorted_entries);
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
dh::caching_device_vector<size_t> column_sizes_scan;
GetColumnSizesScan(device, &column_sizes_scan,
{sorted_entries.data().get(), sorted_entries.size()},
num_columns);
thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan);
data::IsValidFunctor dummy_is_valid(std::numeric_limits<float>::quiet_NaN());
auto batch_it = dh::MakeTransformIterator<data::COOTuple>(
sorted_entries.data().get(),
[] __device__(Entry const &e) -> data::COOTuple {
return {0, e.index, e.fvalue}; // row_idx is not needed for scaning column size.
});
detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature,
batch_it, dummy_is_valid,
0, sorted_entries.size(),
&cuts_ptr, &column_sizes_scan);
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
// Extract cuts
dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts_per_feature);
ExtractWeightedCuts(device, num_cuts_per_feature,
dh::ToSpan(sorted_entries),
dh::ToSpan(temp_weights),
dh::ToSpan(column_sizes_scan),
dh::ToSpan(cuts));
detail::ExtractWeightedCutsSparse(device, d_cuts_ptr,
dh::ToSpan(sorted_entries),
dh::ToSpan(temp_weights),
dh::ToSpan(column_sizes_scan),
dh::ToSpan(cuts));
// add cuts into sketches
thrust::host_vector<SketchEntry> host_cuts(cuts);
sketch_container->Push(num_cuts_per_feature, host_cuts, host_column_sizes_scan);
sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
}
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
size_t sketch_batch_num_elements) {
// Configure batch size based on available memory
bool has_weights = dmat->Info().weights_.Size() > 0;
size_t num_cuts_per_feature = RequiredSampleCuts(max_bins, dmat->Info().num_row_);
sketch_batch_num_elements = SketchBatchNumElements(
size_t num_cuts_per_feature =
detail::RequiredSampleCutsPerColumn(max_bins, dmat->Info().num_row_);
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements,
dmat->Info().num_col_, device, num_cuts_per_feature, has_weights);
dmat->Info().num_row_,
dmat->Info().num_col_,
dmat->Info().num_nonzero_,
device, num_cuts_per_feature, has_weights);
HistogramCuts cuts;
DenseCuts dense_cuts(&cuts);
SketchContainer sketch_container(max_bins, dmat->Info().num_col_,
dmat->Info().num_row_);
dmat->Info().num_row_, device);
dmat->Info().weights_.SetDevice(device);
for (const auto& batch : dmat->GetBatches<SparsePage>()) {
@ -261,8 +343,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
}
}
}
dense_cuts.Init(&sketch_container.sketches_, max_bins, dmat->Info().num_row_);
sketch_container.MakeCuts(&cuts);
return cuts;
}
} // namespace common

View File

@ -1,5 +1,8 @@
/*!
* Copyright 2020 XGBoost contributors
*
* \brief Front end and utilities for GPU based sketching. Works on sliding window
* instead of stream.
*/
#ifndef COMMON_HIST_UTIL_CUH_
#define COMMON_HIST_UTIL_CUH_
@ -7,74 +10,15 @@
#include <thrust/host_vector.h>
#include "hist_util.h"
#include "threading_utils.h"
#include "quantile.cuh"
#include "device_helpers.cuh"
#include "timer.h"
#include "../data/device_adapter.cuh"
namespace xgboost {
namespace common {
using WQSketch = DenseCuts::WQSketch;
using SketchEntry = WQSketch::Entry;
/*!
* \brief A container that holds the device sketches across all
* sparse page batches which are distributed to different devices.
* As sketches are aggregated by column, the mutex guards
* multiple devices pushing sketch summary for the same column
* across distinct rows.
*/
struct SketchContainer {
std::vector<DenseCuts::WQSketch> sketches_; // NOLINT
static constexpr int kOmpNumColsParallelizeLimit = 1000;
static constexpr float kFactor = 8;
SketchContainer(int max_bin, size_t num_columns, size_t num_rows) {
// Initialize Sketches for this dmatrix
sketches_.resize(num_columns);
#pragma omp parallel for schedule(static) if (num_columns > kOmpNumColsParallelizeLimit) // NOLINT
for (int icol = 0; icol < num_columns; ++icol) { // NOLINT
sketches_[icol].Init(num_rows, 1.0 / (8 * max_bin));
}
}
/**
* \brief Pushes cuts to the sketches.
*
* \param entries_per_column The entries per column.
* \param entries Vector of cuts from all columns, length
* entries_per_column * num_columns. \param column_scan Exclusive scan
* of column sizes. Used to detect cases where there are fewer entries than we
* have storage for.
*/
void Push(size_t entries_per_column,
const thrust::host_vector<SketchEntry>& entries,
const thrust::host_vector<size_t>& column_scan) {
#pragma omp parallel for schedule(static) if (sketches_.size() > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT
for (int icol = 0; icol < sketches_.size(); ++icol) {
size_t column_size = column_scan[icol + 1] - column_scan[icol];
if (column_size == 0) continue;
WQuantileSketch<bst_float, bst_float>::SummaryContainer summary;
size_t num_available_cuts =
std::min(size_t(entries_per_column), column_size);
summary.Reserve(num_available_cuts);
summary.MakeFromSorted(&entries[entries_per_column * icol],
num_available_cuts);
sketches_[icol].PushSummary(summary);
}
}
// Prevent copying/assigning/moving this as its internals can't be
// assigned/copied/moved
SketchContainer(const SketchContainer&) = delete;
SketchContainer(SketchContainer&& that) {
std::swap(sketches_, that.sketches_);
}
SketchContainer& operator=(const SketchContainer&) = delete;
SketchContainer& operator=(SketchContainer&&) = delete;
};
namespace detail {
struct EntryCompareOp {
__device__ bool operator()(const Entry& a, const Entry& b) {
if (a.index == b.index) {
@ -88,100 +32,105 @@ struct EntryCompareOp {
* \brief Extracts the cuts from sorted data.
*
* \param device The device.
* \param cuts Output cuts
* \param num_cuts_per_feature Number of cuts per feature.
* \param cuts_ptr Column pointers to CSC structured cuts
* \param sorted_data Sorted entries in segments of columns
* \param column_sizes_scan Describes the boundaries of column segments in
* sorted data
* \param column_sizes_scan Describes the boundaries of column segments in sorted data
* \param out_cuts Output cut values
*/
void ExtractCuts(int device,
size_t num_cuts_per_feature,
Span<Entry const> sorted_data,
Span<size_t const> column_sizes_scan,
Span<SketchEntry> out_cuts);
void ExtractCutsSparse(int device, common::Span<SketchContainer::OffsetT const> cuts_ptr,
Span<Entry const> sorted_data,
Span<size_t const> column_sizes_scan,
Span<SketchEntry> out_cuts);
// Count the entries in each column and exclusive scan
inline void GetColumnSizesScan(int device,
dh::caching_device_vector<size_t>* column_sizes_scan,
Span<const Entry> entries, size_t num_columns) {
column_sizes_scan->resize(num_columns + 1, 0);
auto d_column_sizes_scan = column_sizes_scan->data().get();
auto d_entries = entries.data();
dh::LaunchN(device, entries.size(), [=] __device__(size_t idx) {
auto& e = d_entries[idx];
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
&d_column_sizes_scan[e.index]),
static_cast<unsigned long long>(1)); // NOLINT
});
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
column_sizes_scan->end(), column_sizes_scan->begin());
}
/**
* \brief Extracts the cuts from sorted data, considering weights.
*
* \param device The device.
* \param cuts_ptr Column pointers to CSC structured cuts
* \param sorted_data Sorted entries in segments of columns.
* \param weights_scan Inclusive scan of weights for each entry in sorted_data.
* \param column_sizes_scan Describes the boundaries of column segments in sorted data.
* \param cuts Output cuts.
*/
void ExtractWeightedCutsSparse(int device,
common::Span<SketchContainer::OffsetT const> cuts_ptr,
Span<Entry> sorted_data,
Span<float> weights_scan,
Span<size_t> column_sizes_scan,
Span<SketchEntry> cuts);
// For adapter.
// Get column size from adapter batch and for output cuts.
template <typename Iter>
void GetColumnSizesScan(int device, size_t num_columns,
void GetColumnSizesScan(int device, size_t num_columns, size_t num_cuts_per_feature,
Iter batch_iter, data::IsValidFunctor is_valid,
size_t begin, size_t end,
HostDeviceVector<SketchContainer::OffsetT> *cuts_ptr,
dh::caching_device_vector<size_t>* column_sizes_scan) {
dh::XGBCachingDeviceAllocator<char> alloc;
column_sizes_scan->resize(num_columns + 1, 0);
cuts_ptr->SetDevice(device);
cuts_ptr->Resize(num_columns + 1, 0);
dh::XGBCachingDeviceAllocator<char> alloc;
auto d_column_sizes_scan = column_sizes_scan->data().get();
dh::LaunchN(device, end - begin, [=] __device__(size_t idx) {
auto e = batch_iter[begin + idx];
if (is_valid(e)) {
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
&d_column_sizes_scan[e.column_idx]),
static_cast<unsigned long long>(1)); // NOLINT
atomicAdd(&d_column_sizes_scan[e.column_idx], static_cast<size_t>(1));
}
});
// Calculate cuts CSC pointer
auto cut_ptr_it = dh::MakeTransformIterator<size_t>(
column_sizes_scan->begin(), [=] __device__(size_t column_size) {
return thrust::min(num_cuts_per_feature, column_size);
});
thrust::exclusive_scan(thrust::cuda::par(alloc), cut_ptr_it,
cut_ptr_it + column_sizes_scan->size(),
cuts_ptr->DevicePointer());
thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
column_sizes_scan->end(), column_sizes_scan->begin());
}
inline size_t BytesPerElement(bool has_weight) {
inline size_t constexpr BytesPerElement(bool has_weight) {
// Double the memory usage for sorting. We need to assign weight for each element, so
// sizeof(float) is added to all elements.
return (has_weight ? sizeof(Entry) + sizeof(float) : sizeof(Entry)) * 2;
}
inline size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
size_t columns, int device,
size_t num_cuts, bool has_weight) {
if (sketch_batch_num_elements == 0) {
size_t bytes_per_element = BytesPerElement(has_weight);
size_t bytes_cuts = num_cuts * columns * sizeof(SketchEntry);
size_t bytes_num_columns = (columns + 1) * sizeof(size_t);
// use up to 80% of available space
sketch_batch_num_elements = (dh::AvailableMemory(device) -
bytes_cuts - bytes_num_columns) *
0.8 / bytes_per_element;
}
return sketch_batch_num_elements;
}
/* \brief Calcuate the length of sliding window. Returns `sketch_batch_num_elements`
* directly if it's not 0.
*/
size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
bst_row_t num_rows, size_t columns, size_t nnz, int device,
size_t num_cuts, bool has_weight);
// Compute number of sample cuts needed on local node to maintain accuracy
// We take more cuts than needed and then reduce them later
inline size_t RequiredSampleCuts(int max_bins, size_t num_rows) {
double eps = 1.0 / (SketchContainer::kFactor * max_bins);
size_t dummy_nlevel;
size_t num_cuts;
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
num_rows, eps, &dummy_nlevel, &num_cuts);
return std::min(num_cuts, num_rows);
}
// sketch_batch_num_elements 0 means autodetect. Only modify this for testing.
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
size_t sketch_batch_num_elements = 0);
size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows);
/* \brief Estimate required memory for each sliding window.
*
* It's not precise as to obtain exact memory usage for sparse dataset we need to walk
* through the whole dataset first. Also if data is from host DMatrix, we copy the
* weight, group and offset on first batch, which is not considered in the function.
*
* \param num_rows Number of rows in this worker.
* \param num_columns Number of columns for this dataset.
* \param nnz Number of non-zero element. Put in something greater than rows *
* cols if nnz is unknown.
* \param num_bins Number of histogram bins.
* \param with_weights Whether weight is used, works the same for ranking and other models.
*
* \return The estimated bytes
*/
size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz,
size_t num_bins, bool with_weights);
// Count the valid entries in each column and copy them out.
template <typename AdapterBatch, typename BatchIter>
void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
Range1d range, float missing,
size_t columns, int device,
thrust::host_vector<size_t>* host_column_sizes_scan,
size_t columns, size_t cuts_per_feature, int device,
HostDeviceVector<SketchContainer::OffsetT>* cut_sizes_scan,
dh::caching_device_vector<size_t>* column_sizes_scan,
dh::caching_device_vector<Entry>* sorted_entries) {
auto entry_iter = dh::MakeTransformIterator<Entry>(
@ -191,16 +140,12 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
});
data::IsValidFunctor is_valid(missing);
// Work out how many valid entries we have in each column
GetColumnSizesScan(device, columns,
GetColumnSizesScan(device, columns, cuts_per_feature,
batch_iter, is_valid,
range.begin(), range.end(),
cut_sizes_scan,
column_sizes_scan);
host_column_sizes_scan->resize(column_sizes_scan->size());
thrust::copy(column_sizes_scan->begin(), column_sizes_scan->end(),
host_column_sizes_scan->begin());
size_t num_valid = host_column_sizes_scan->back();
size_t num_valid = column_sizes_scan->back();
// Copy current subset of valid elements into temporary storage and sort
sorted_entries->resize(num_valid);
dh::XGBCachingDeviceAllocator<char> alloc;
@ -208,6 +153,16 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
entry_iter + range.end(), sorted_entries->begin(), is_valid);
}
void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
dh::caching_device_vector<float>* weights,
dh::caching_device_vector<Entry>* sorted_entries);
} // namespace detail
// Compute sketch on DMatrix.
// sketch_batch_num_elements 0 means autodetect. Only modify this for testing.
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
size_t sketch_batch_num_elements = 0);
template <typename AdapterBatch>
void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns,
size_t begin, size_t end, float missing,
@ -215,41 +170,33 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns,
// Copy current subset of valid elements into temporary storage and sort
dh::caching_device_vector<Entry> sorted_entries;
dh::caching_device_vector<size_t> column_sizes_scan;
thrust::host_vector<size_t> host_column_sizes_scan;
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
thrust::make_counting_iterator(0llu),
[=] __device__(size_t idx) { return batch.GetElement(idx); });
MakeEntriesFromAdapter(batch, batch_iter, {begin, end}, missing, columns, device,
&host_column_sizes_scan,
&column_sizes_scan,
&sorted_entries);
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
detail::MakeEntriesFromAdapter(batch, batch_iter, {begin, end}, missing,
columns, num_cuts, device,
&cuts_ptr,
&column_sizes_scan,
&sorted_entries);
dh::XGBCachingDeviceAllocator<char> alloc;
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), EntryCompareOp());
sorted_entries.end(), detail::EntryCompareOp());
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
// Extract the cuts from all columns concurrently
dh::caching_device_vector<SketchEntry> cuts(columns * num_cuts);
ExtractCuts(device, num_cuts,
dh::ToSpan(sorted_entries),
dh::ToSpan(column_sizes_scan),
dh::ToSpan(cuts));
detail::ExtractCutsSparse(device, d_cuts_ptr,
dh::ToSpan(sorted_entries),
dh::ToSpan(column_sizes_scan),
dh::ToSpan(cuts));
sorted_entries.clear();
sorted_entries.shrink_to_fit();
// Push cuts into sketches stored in host memory
thrust::host_vector<SketchEntry> host_cuts(cuts);
sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
}
void ExtractWeightedCuts(int device,
size_t num_cuts_per_feature,
Span<Entry> sorted_data,
Span<float> weights_scan,
Span<size_t> column_sizes_scan,
Span<SketchEntry> cuts);
void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
dh::caching_device_vector<float>* weights,
dh::caching_device_vector<Entry>* sorted_entries);
template <typename Batch>
void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
int num_cuts_per_feature,
@ -268,12 +215,13 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
[=] __device__(size_t idx) { return batch.GetElement(idx); });
dh::caching_device_vector<Entry> sorted_entries;
dh::caching_device_vector<size_t> column_sizes_scan;
thrust::host_vector<size_t> host_column_sizes_scan;
MakeEntriesFromAdapter(batch, batch_iter,
{begin, end}, missing, columns, device,
&host_column_sizes_scan,
&column_sizes_scan,
&sorted_entries);
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
detail::MakeEntriesFromAdapter(batch, batch_iter,
{begin, end}, missing,
columns, num_cuts_per_feature, device,
&cuts_ptr,
&column_sizes_scan,
&sorted_entries);
data::IsValidFunctor is_valid(missing);
dh::caching_device_vector<float> temp_weights(sorted_entries.size());
@ -297,6 +245,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
is_valid);
CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size());
} else {
CHECK_EQ(batch.NumRows(), weights.size());
auto const weight_iter = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0lu),
[=]__device__(size_t idx) -> float {
@ -310,90 +259,114 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size());
}
SortByWeight(&alloc, &temp_weights, &sorted_entries);
// Extract cuts
dh::caching_device_vector<SketchEntry> cuts(columns * num_cuts_per_feature);
ExtractWeightedCuts(device, num_cuts_per_feature,
dh::ToSpan(sorted_entries),
dh::ToSpan(temp_weights),
dh::ToSpan(column_sizes_scan),
dh::ToSpan(cuts));
detail::SortByWeight(&alloc, &temp_weights, &sorted_entries);
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
// Extract cuts
dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
detail::ExtractWeightedCutsSparse(device, d_cuts_ptr,
dh::ToSpan(sorted_entries),
dh::ToSpan(temp_weights),
dh::ToSpan(column_sizes_scan),
dh::ToSpan(cuts));
sorted_entries.clear();
sorted_entries.shrink_to_fit();
// add cuts into sketches
thrust::host_vector<SketchEntry> host_cuts(cuts);
sketch_container->Push(num_cuts_per_feature, host_cuts, host_column_sizes_scan);
sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
}
template <typename AdapterT>
HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
float missing,
size_t sketch_batch_num_elements = 0) {
size_t num_cuts = RequiredSampleCuts(num_bins, adapter->NumRows());
size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, adapter->NumRows());
CHECK(adapter->NumRows() != data::kAdapterUnknownSize);
CHECK(adapter->NumColumns() != data::kAdapterUnknownSize);
adapter->BeforeFirst();
adapter->Next();
auto& batch = adapter->Value();
sketch_batch_num_elements = SketchBatchNumElements(
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements,
adapter->NumColumns(), adapter->DeviceIdx(), num_cuts, false);
adapter->NumRows(), adapter->NumColumns(), std::numeric_limits<size_t>::max(),
adapter->DeviceIdx(),
num_cuts_per_feature, false);
// Enforce single batch
CHECK(!adapter->Next());
HistogramCuts cuts;
DenseCuts dense_cuts(&cuts);
SketchContainer sketch_container(num_bins, adapter->NumColumns(),
adapter->NumRows());
adapter->NumRows(), adapter->DeviceIdx());
for (auto begin = 0ull; begin < batch.Size();
begin += sketch_batch_num_elements) {
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
auto const& batch = adapter->Value();
ProcessSlidingWindow(batch, adapter->DeviceIdx(), adapter->NumColumns(),
begin, end, missing, &sketch_container, num_cuts);
begin, end, missing, &sketch_container, num_cuts_per_feature);
}
dense_cuts.Init(&sketch_container.sketches_, num_bins, adapter->NumRows());
sketch_container.MakeCuts(&cuts);
return cuts;
}
/*
* \brief Perform sketching on GPU.
*
* \param batch A batch from adapter.
* \param num_bins Bins per column.
* \param missing Floating point value that represents invalid value.
* \param sketch_container Container for output sketch.
* \param sketch_batch_num_elements Number of element per-sliding window, use it only for
* testing.
*/
template <typename Batch>
void AdapterDeviceSketch(Batch batch, int num_bins,
float missing, int device,
SketchContainer* sketch_container,
float missing, SketchContainer* sketch_container,
size_t sketch_batch_num_elements = 0) {
size_t num_rows = batch.NumRows();
size_t num_cols = batch.NumCols();
size_t num_cuts = RequiredSampleCuts(num_bins, num_rows);
sketch_batch_num_elements = SketchBatchNumElements(
size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows);
int32_t device = sketch_container->DeviceIdx();
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements,
num_cols, device, num_cuts, false);
num_rows, num_cols, std::numeric_limits<size_t>::max(),
device, num_cuts_per_feature, false);
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
ProcessSlidingWindow(batch, device, num_cols,
begin, end, missing, sketch_container, num_cuts);
begin, end, missing, sketch_container, num_cuts_per_feature);
}
}
/*
* \brief Perform weighted sketching on GPU.
*
* When weight in info is empty, this function is equivalent to unweighted version.
*/
template <typename Batch>
void AdapterDeviceSketchWeighted(Batch batch, int num_bins,
MetaInfo const& info,
float missing,
int device,
SketchContainer* sketch_container,
float missing, SketchContainer* sketch_container,
size_t sketch_batch_num_elements = 0) {
if (info.weights_.Size() == 0) {
return AdapterDeviceSketch(batch, num_bins, missing, sketch_container, sketch_batch_num_elements);
}
size_t num_rows = batch.NumRows();
size_t num_cols = batch.NumCols();
size_t num_cuts = RequiredSampleCuts(num_bins, num_rows);
sketch_batch_num_elements = SketchBatchNumElements(
size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows);
int32_t device = sketch_container->DeviceIdx();
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements,
num_cols, device, num_cuts, true);
num_rows, num_cols, std::numeric_limits<size_t>::max(),
device, num_cuts_per_feature, true);
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
ProcessWeightedSlidingWindow(batch, info,
num_cuts,
num_cuts_per_feature,
CutsBuilder::UseGroup(info), missing, device, num_cols, begin, end,
sketch_container);
}

View File

@ -167,7 +167,7 @@ class CutsBuilder {
/*! \brief Cut configuration for sparse dataset. */
class SparseCuts : public CutsBuilder {
/* \brief Distrbute columns to each thread according to number of entries. */
/* \brief Distribute columns to each thread according to number of entries. */
static std::vector<size_t> LoadBalance(SparsePage const& page, size_t const nthreads);
Monitor monitor_;

View File

@ -205,10 +205,10 @@ class HostDeviceVectorImpl {
// data is on the host
LazyResizeDevice(data_h_.size());
SetDevice();
dh::safe_cuda(cudaMemcpy(data_d_->data().get(),
data_h_.data(),
data_d_->size() * sizeof(T),
cudaMemcpyHostToDevice));
dh::safe_cuda(cudaMemcpyAsync(data_d_->data().get(),
data_h_.data(),
data_d_->size() * sizeof(T),
cudaMemcpyHostToDevice));
gpu_access_ = access;
}

572
src/common/quantile.cu Normal file
View File

@ -0,0 +1,572 @@
/*!
* Copyright 2020 by XGBoost Contributors
*/
#include <thrust/unique.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/binary_search.h>
#include <thrust/transform_scan.h>
#include <thrust/execution_policy.h>
#include <memory>
#include <utility>
#include "xgboost/span.h"
#include "quantile.h"
#include "quantile.cuh"
#include "hist_util.h"
#include "device_helpers.cuh"
#include "common.h"
namespace xgboost {
namespace common {
using WQSketch = DenseCuts::WQSketch;
using SketchEntry = WQSketch::Entry;
// Algorithm 4 in XGBoost's paper, using binary search to find i.
__device__ SketchEntry BinarySearchQuery(Span<SketchEntry const> const& entries, float rank) {
assert(entries.size() >= 2);
rank *= 2;
if (rank < entries.front().rmin + entries.front().rmax) {
return entries.front();
}
if (rank >= entries.back().rmin + entries.back().rmax) {
return entries.back();
}
auto begin = dh::MakeTransformIterator<float>(
entries.begin(), [=] __device__(SketchEntry const &entry) {
return entry.rmin + entry.rmax;
});
auto end = begin + entries.size();
auto i = thrust::upper_bound(thrust::seq, begin + 1, end - 1, rank) - begin - 1;
if (rank < entries[i].RMinNext() + entries[i+1].RMaxPrev()) {
return entries[i];
} else {
return entries[i+1];
}
}
template <typename T>
void CopyTo(Span<T> out, Span<T const> src) {
CHECK_EQ(out.size(), src.size());
dh::safe_cuda(cudaMemcpyAsync(out.data(), src.data(),
out.size_bytes(),
cudaMemcpyDefault));
}
// Compute the merge path.
common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
Span<SketchEntry const> const &d_x, Span<bst_row_t const> const &x_ptr,
Span<SketchEntry const> const &d_y, Span<bst_row_t const> const &y_ptr,
Span<SketchEntry> out, Span<bst_row_t> out_ptr) {
auto x_merge_key_it = thrust::make_zip_iterator(thrust::make_tuple(
dh::MakeTransformIterator<bst_row_t>(
thrust::make_counting_iterator(0ul),
[=] __device__(size_t idx) { return dh::SegmentId(x_ptr, idx); }),
d_x.data()));
auto y_merge_key_it = thrust::make_zip_iterator(thrust::make_tuple(
dh::MakeTransformIterator<bst_row_t>(
thrust::make_counting_iterator(0ul),
[=] __device__(size_t idx) { return dh::SegmentId(y_ptr, idx); }),
d_y.data()));
using Tuple = thrust::tuple<uint64_t, uint64_t>;
thrust::constant_iterator<uint64_t> a_ind_iter(0ul);
thrust::constant_iterator<uint64_t> b_ind_iter(1ul);
auto place_holder = thrust::make_constant_iterator<uint64_t>(0u);
auto x_merge_val_it =
thrust::make_zip_iterator(thrust::make_tuple(a_ind_iter, place_holder));
auto y_merge_val_it =
thrust::make_zip_iterator(thrust::make_tuple(b_ind_iter, place_holder));
dh::XGBCachingDeviceAllocator<Tuple> alloc;
static_assert(sizeof(Tuple) == sizeof(SketchEntry), "");
// We reuse the memory for storing merge path.
common::Span<Tuple> merge_path{reinterpret_cast<Tuple *>(out.data()), out.size()};
// Determine the merge path, 0 if element is from x, 1 if it's from y.
thrust::merge_by_key(
thrust::cuda::par(alloc), x_merge_key_it, x_merge_key_it + d_x.size(),
y_merge_key_it, y_merge_key_it + d_y.size(), x_merge_val_it,
y_merge_val_it, thrust::make_discard_iterator(), merge_path.data(),
[=] __device__(auto const &l, auto const &r) -> bool {
auto l_column_id = thrust::get<0>(l);
auto r_column_id = thrust::get<0>(r);
if (l_column_id == r_column_id) {
return thrust::get<1>(l).value < thrust::get<1>(r).value;
}
return l_column_id < r_column_id;
});
// Compute output ptr
auto transform_it =
thrust::make_zip_iterator(thrust::make_tuple(x_ptr.data(), y_ptr.data()));
thrust::transform(
thrust::cuda::par(alloc), transform_it, transform_it + x_ptr.size(),
out_ptr.data(),
[] __device__(auto const& t) { return thrust::get<0>(t) + thrust::get<1>(t); });
// 0^th is the indicator, 1^th is placeholder
auto get_ind = []XGBOOST_DEVICE(Tuple const& t) { return thrust::get<0>(t); };
// 0^th is the counter for x, 1^th for y.
auto get_x = []XGBOOST_DEVICE(Tuple const &t) { return thrust::get<0>(t); };
auto get_y = []XGBOOST_DEVICE(Tuple const &t) { return thrust::get<1>(t); };
auto scan_key_it = dh::MakeTransformIterator<size_t>(
thrust::make_counting_iterator(0ul),
[=] __device__(size_t idx) { return dh::SegmentId(out_ptr, idx); });
auto scan_val_it = dh::MakeTransformIterator<Tuple>(
merge_path.data(), [=] __device__(Tuple const &t) -> Tuple {
auto ind = get_ind(t); // == 0 if element is from x
// x_counter, y_counter
return thrust::make_tuple<uint64_t, uint64_t>(!ind, ind);
});
// Compute the index for both x and y (which of the element in a and b are used in each
// comparison) by scaning the binary merge path. Take output [(x_0, y_0), (x_0, y_1),
// ...] as an example, the comparison between (x_0, y_0) adds 1 step in the merge path.
// Asumming y_0 is less than x_0 so this step is torward the end of y. After the
// comparison, index of y is incremented by 1 from y_0 to y_1, and at the same time, y_0
// is landed into output as the first element in merge result. The scan result is the
// subscript of x and y.
thrust::exclusive_scan_by_key(
thrust::cuda::par(alloc), scan_key_it, scan_key_it + merge_path.size(),
scan_val_it, merge_path.data(),
thrust::make_tuple<uint64_t, uint64_t>(0ul, 0ul),
thrust::equal_to<size_t>{},
[=] __device__(Tuple const &l, Tuple const &r) -> Tuple {
return thrust::make_tuple(get_x(l) + get_x(r), get_y(l) + get_y(r));
});
return merge_path;
}
// Merge d_x and d_y into out. Because the final output depends on predicate (which
// summary does the output element come from) result by definition of merged rank. So we
// run it in 2 passes to obtain the merge path and then customize the standard merge
// algorithm.
void MergeImpl(int32_t device, Span<SketchEntry const> const &d_x,
Span<bst_row_t const> const &x_ptr,
Span<SketchEntry const> const &d_y,
Span<bst_row_t const> const &y_ptr,
Span<SketchEntry> out,
Span<bst_row_t> out_ptr) {
dh::safe_cuda(cudaSetDevice(device));
CHECK_EQ(d_x.size() + d_y.size(), out.size());
CHECK_EQ(x_ptr.size(), out_ptr.size());
CHECK_EQ(y_ptr.size(), out_ptr.size());
auto d_merge_path = MergePath(d_x, x_ptr, d_y, y_ptr, out, out_ptr);
auto d_out = out;
dh::LaunchN(device, d_out.size(), [=] __device__(size_t idx) {
auto column_id = dh::SegmentId(out_ptr, idx);
idx -= out_ptr[column_id];
auto d_x_column =
d_x.subspan(x_ptr[column_id], x_ptr[column_id + 1] - x_ptr[column_id]);
auto d_y_column =
d_y.subspan(y_ptr[column_id], y_ptr[column_id + 1] - y_ptr[column_id]);
auto d_out_column = d_out.subspan(
out_ptr[column_id], out_ptr[column_id + 1] - out_ptr[column_id]);
auto d_path_column = d_merge_path.subspan(
out_ptr[column_id], out_ptr[column_id + 1] - out_ptr[column_id]);
uint64_t a_ind, b_ind;
thrust::tie(a_ind, b_ind) = d_path_column[idx];
// Handle empty column. If both columns are empty, we should not get this column_id
// as result of binary search.
assert((d_x_column.size() != 0) || (d_y_column.size() != 0));
if (d_x_column.size() == 0) {
d_out_column[idx] = d_y_column[b_ind];
return;
}
if (d_y_column.size() == 0) {
d_out_column[idx] = d_x_column[a_ind];
return;
}
// Handle trailing elements.
assert(a_ind <= d_x_column.size());
if (a_ind == d_x_column.size()) {
// Trailing elements are from y because there's no more x to land.
auto y_elem = d_y_column[b_ind];
d_out_column[idx] = SketchEntry(y_elem.rmin + d_x_column.back().RMinNext(),
y_elem.rmax + d_x_column.back().rmax,
y_elem.wmin, y_elem.value);
return;
}
auto x_elem = d_x_column[a_ind];
assert(b_ind <= d_y_column.size());
if (b_ind == d_y_column.size()) {
d_out_column[idx] = SketchEntry(x_elem.rmin + d_y_column.back().RMinNext(),
x_elem.rmax + d_y_column.back().rmax,
x_elem.wmin, x_elem.value);
return;
}
auto y_elem = d_y_column[b_ind];
/* Merge procedure. See A.3 merge operation eq (26) ~ (28). The trick to interpret
it is rewriting the symbols on both side of equality. Take eq (26) as an example:
Expand it according to definition of extended rank then rewrite it into:
If $k_i$ is the $i$ element in output and \textbf{comes from $D_1$}:
r_\bar{D}(k_i) = r_{\bar{D_1}}(k_i) + w_{\bar{{D_1}}}(k_i) +
[r_{\bar{D_2}}(x_i) + w_{\bar{D_2}}(x_i)]
Where $x_i$ is the largest element in $D_2$ that's less than $k_i$. $k_i$ can be
used in $D_1$ as it's since $k_i \in D_1$. Other 2 equations can be applied
similarly with $k_i$ comes from different $D$. just use different symbol on
different source of summary.
*/
assert(idx < d_out_column.size());
if (x_elem.value == y_elem.value) {
d_out_column[idx] =
SketchEntry{x_elem.rmin + y_elem.rmin, x_elem.rmax + y_elem.rmax,
x_elem.wmin + y_elem.wmin, x_elem.value};
} else if (x_elem.value < y_elem.value) {
// elem from x is landed. yprev_min is the element in D_2 that's 1 rank less than
// x_elem if we put x_elem in D_2.
float yprev_min = b_ind == 0 ? 0.0f : d_y_column[b_ind - 1].RMinNext();
// rmin should be equal to x_elem.rmin + x_elem.wmin + yprev_min. But for
// implementation, the weight is stored in a separated field and we compute the
// extended definition on the fly when needed.
d_out_column[idx] =
SketchEntry{x_elem.rmin + yprev_min, x_elem.rmax + y_elem.RMaxPrev(),
x_elem.wmin, x_elem.value};
} else {
// elem from y is landed.
float xprev_min = a_ind == 0 ? 0.0f : d_x_column[a_ind - 1].RMinNext();
d_out_column[idx] =
SketchEntry{xprev_min + y_elem.rmin, x_elem.RMaxPrev() + y_elem.rmax,
y_elem.wmin, y_elem.value};
}
});
}
void SketchContainer::Push(common::Span<OffsetT const> cuts_ptr,
dh::caching_device_vector<SketchEntry>* entries) {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_));
// Copy or merge the new cuts, pruning is performed during `MakeCuts`.
if (this->Current().size() == 0) {
CHECK_EQ(this->columns_ptr_.Size(), cuts_ptr.size());
// See thrust issue 1030, THRUST_CPP_DIALECT is not correctly defined so
// move constructor is not used.
this->Current().swap(*entries);
CHECK_EQ(entries->size(), 0);
auto d_cuts_ptr = this->columns_ptr_.DevicePointer();
thrust::copy(thrust::device, cuts_ptr.data(),
cuts_ptr.data() + cuts_ptr.size(), d_cuts_ptr);
} else {
auto d_entries = dh::ToSpan(*entries);
this->Merge(cuts_ptr, d_entries);
this->FixError();
}
CHECK_NE(this->columns_ptr_.Size(), 0);
timer_.Stop(__func__);
}
size_t SketchContainer::Unique() {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_));
this->columns_ptr_.SetDevice(device_);
Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
Span<SketchEntry> entries = dh::ToSpan(this->Current());
HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
scan_out.SetDevice(device_);
auto d_scan_out = scan_out.DeviceSpan();
d_column_scan = this->columns_ptr_.DeviceSpan();
size_t n_uniques = dh::SegmentedUnique(
d_column_scan.data(), d_column_scan.data() + d_column_scan.size(),
entries.data(), entries.data() + entries.size(), scan_out.DevicePointer(),
entries.data(),
detail::SketchUnique{});
this->columns_ptr_.Copy(scan_out);
CHECK(!this->columns_ptr_.HostCanRead());
this->Current().resize(n_uniques);
timer_.Stop(__func__);
return n_uniques;
}
void SketchContainer::Prune(size_t to) {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_));
this->Unique();
OffsetT to_total = 0;
HostDeviceVector<OffsetT> new_columns_ptr{to_total};
for (bst_feature_t i = 0; i < num_columns_; ++i) {
size_t length = this->Column(i).size();
length = std::min(length, to);
to_total += length;
new_columns_ptr.HostVector().emplace_back(to_total);
}
new_columns_ptr.SetDevice(device_);
this->Other().resize(to_total);
auto d_columns_ptr_in = this->columns_ptr_.ConstDeviceSpan();
auto d_columns_ptr_out = new_columns_ptr.ConstDeviceSpan();
auto out = dh::ToSpan(this->Other());
auto in = dh::ToSpan(this->Current());
dh::LaunchN(0, to_total, [=] __device__(size_t idx) {
size_t column_id = dh::SegmentId(d_columns_ptr_out, idx);
auto out_column = out.subspan(d_columns_ptr_out[column_id],
d_columns_ptr_out[column_id + 1] -
d_columns_ptr_out[column_id]);
auto in_column = in.subspan(d_columns_ptr_in[column_id],
d_columns_ptr_in[column_id + 1] -
d_columns_ptr_in[column_id]);
idx -= d_columns_ptr_out[column_id];
// Input has lesser columns than `to`, just copy them to the output. This is correct
// as the new output size is calculated based on both the size of `to` and current
// column.
if (in_column.size() <= to) {
out_column[idx] = in_column[idx];
return;
}
// 1 thread for each output. See A.4 for detail.
auto entries = in_column;
auto d_out = out_column;
if (idx == 0) {
d_out.front() = entries.front();
return;
}
if (idx == to - 1) {
d_out.back() = entries.back();
return;
}
float w = entries.back().rmin - entries.front().rmax;
assert(w != 0);
auto budget = static_cast<float>(d_out.size());
assert(budget != 0);
auto q = ((idx * w) / (to - 1) + entries.front().rmax);
d_out[idx] = BinarySearchQuery(entries, q);
});
this->columns_ptr_.HostVector() = new_columns_ptr.HostVector();
this->Alternate();
timer_.Stop(__func__);
}
void SketchContainer::Merge(Span<OffsetT const> d_that_columns_ptr,
Span<SketchEntry const> that) {
dh::safe_cuda(cudaSetDevice(device_));
timer_.Start(__func__);
if (this->Current().size() == 0) {
CHECK_EQ(this->columns_ptr_.HostVector().back(), 0);
CHECK_EQ(this->columns_ptr_.HostVector().size(), d_that_columns_ptr.size());
CHECK_EQ(columns_ptr_.Size(), num_columns_ + 1);
thrust::copy(thrust::device, d_that_columns_ptr.data(),
d_that_columns_ptr.data() + d_that_columns_ptr.size(),
this->columns_ptr_.DevicePointer());
auto total = this->columns_ptr_.HostVector().back();
this->Current().resize(total);
CopyTo(dh::ToSpan(this->Current()), that);
timer_.Stop(__func__);
return;
}
this->Other().resize(this->Current().size() + that.size());
CHECK_EQ(d_that_columns_ptr.size(), this->columns_ptr_.Size());
HostDeviceVector<OffsetT> new_columns_ptr;
new_columns_ptr.SetDevice(device_);
new_columns_ptr.Resize(this->ColumnsPtr().size());
MergeImpl(device_, this->Data(), this->ColumnsPtr(),
that, d_that_columns_ptr,
dh::ToSpan(this->Other()), new_columns_ptr.DeviceSpan());
this->columns_ptr_ = std::move(new_columns_ptr);
CHECK_EQ(this->columns_ptr_.Size(), num_columns_ + 1);
CHECK_EQ(new_columns_ptr.Size(), 0);
this->Alternate();
timer_.Stop(__func__);
}
void SketchContainer::FixError() {
dh::safe_cuda(cudaSetDevice(device_));
auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
auto in = dh::ToSpan(this->Current());
dh::LaunchN(device_, in.size(), [=] __device__(size_t idx) {
auto column_id = dh::SegmentId(d_columns_ptr, idx);
auto in_column = in.subspan(d_columns_ptr[column_id],
d_columns_ptr[column_id + 1] -
d_columns_ptr[column_id]);
idx -= d_columns_ptr[column_id];
float prev_rmin = idx == 0 ? 0.0f : in_column[idx-1].rmin;
if (in_column[idx].rmin < prev_rmin) {
in_column[idx].rmin = prev_rmin;
}
float prev_rmax = idx == 0 ? 0.0f : in_column[idx-1].rmax;
if (in_column[idx].rmax < prev_rmax) {
in_column[idx].rmax = prev_rmax;
}
float rmin_next = in_column[idx].RMinNext();
if (in_column[idx].rmax < rmin_next) {
in_column[idx].rmax = rmin_next;
}
});
}
void SketchContainer::AllReduce() {
dh::safe_cuda(cudaSetDevice(device_));
auto world = rabit::GetWorldSize();
if (world == 1) {
return;
}
timer_.Start(__func__);
if (!reducer_) {
reducer_ = std::make_unique<dh::AllReducer>();
reducer_->Init(device_);
}
// Reduce the overhead on syncing.
size_t global_sum_rows = num_rows_;
rabit::Allreduce<rabit::op::Sum>(&global_sum_rows, 1);
size_t intermediate_num_cuts =
std::min(global_sum_rows, static_cast<size_t>(num_bins_ * kFactor));
this->Prune(intermediate_num_cuts);
auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
CHECK_EQ(d_columns_ptr.size(), num_columns_ + 1);
size_t n = d_columns_ptr.size();
rabit::Allreduce<rabit::op::Max>(&n, 1);
CHECK_EQ(n, d_columns_ptr.size()) << "Number of columns differs across workers";
// Get the columns ptr from all workers
dh::device_vector<SketchContainer::OffsetT> gathered_ptrs;
gathered_ptrs.resize(d_columns_ptr.size() * world, 0);
size_t rank = rabit::GetRank();
auto offset = rank * d_columns_ptr.size();
thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(),
gathered_ptrs.begin() + offset);
reducer_->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.data().get(),
gathered_ptrs.size());
// Get the data from all workers.
std::vector<size_t> recv_lengths;
dh::caching_device_vector<char> recvbuf;
reducer_->AllGather(this->Current().data().get(),
dh::ToSpan(this->Current()).size_bytes(), &recv_lengths,
&recvbuf);
reducer_->Synchronize();
// Segment the received data.
auto s_recvbuf = dh::ToSpan(recvbuf);
std::vector<Span<SketchEntry>> allworkers;
offset = 0;
for (int32_t i = 0; i < world; ++i) {
size_t length_as_bytes = recv_lengths.at(i);
auto raw = s_recvbuf.subspan(offset, length_as_bytes);
auto sketch = Span<SketchEntry>(reinterpret_cast<SketchEntry *>(raw.data()),
length_as_bytes / sizeof(SketchEntry));
allworkers.emplace_back(sketch);
offset += length_as_bytes;
}
// Merge them into a new sketch.
SketchContainer new_sketch(num_bins_, this->num_columns_, global_sum_rows,
this->device_);
for (size_t i = 0; i < allworkers.size(); ++i) {
auto worker = allworkers[i];
auto worker_ptr =
dh::ToSpan(gathered_ptrs)
.subspan(i * d_columns_ptr.size(), d_columns_ptr.size());
new_sketch.Merge(worker_ptr, worker);
new_sketch.FixError();
}
*this = std::move(new_sketch);
timer_.Stop(__func__);
}
void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_));
p_cuts->min_vals_.Resize(num_columns_);
// Sync between workers.
this->AllReduce();
// Prune to final number of bins.
this->Prune(num_bins_ + 1);
this->Unique();
this->FixError();
// Set up inputs
auto d_in_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
p_cuts->min_vals_.SetDevice(device_);
auto d_min_values = p_cuts->min_vals_.DeviceSpan();
auto in_cut_values = dh::ToSpan(this->Current());
// Set up output ptr
p_cuts->cut_ptrs_.SetDevice(device_);
auto& h_out_columns_ptr = p_cuts->cut_ptrs_.HostVector();
h_out_columns_ptr.clear();
h_out_columns_ptr.push_back(0);
for (bst_feature_t i = 0; i < num_columns_; ++i) {
h_out_columns_ptr.push_back(
std::min(static_cast<size_t>(std::max(static_cast<size_t>(1ul),
this->Column(i).size())),
static_cast<size_t>(num_bins_)));
}
std::partial_sum(h_out_columns_ptr.begin(), h_out_columns_ptr.end(),
h_out_columns_ptr.begin());
auto d_out_columns_ptr = p_cuts->cut_ptrs_.ConstDeviceSpan();
// Set up output cuts
size_t total_bins = h_out_columns_ptr.back();
p_cuts->cut_values_.SetDevice(device_);
p_cuts->cut_values_.Resize(total_bins);
auto out_cut_values = p_cuts->cut_values_.DeviceSpan();
dh::LaunchN(0, total_bins, [=] __device__(size_t idx) {
auto column_id = dh::SegmentId(d_out_columns_ptr, idx);
auto in_column = in_cut_values.subspan(d_in_columns_ptr[column_id],
d_in_columns_ptr[column_id + 1] -
d_in_columns_ptr[column_id]);
auto out_column = out_cut_values.subspan(d_out_columns_ptr[column_id],
d_out_columns_ptr[column_id + 1] -
d_out_columns_ptr[column_id]);
idx -= d_out_columns_ptr[column_id];
if (in_column.size() == 0) {
// If the column is empty, we push a dummy value. It won't affect training as the
// column is empty, trees cannot split on it. This is just to be consistent with
// rest of the library.
if (idx == 0) {
d_min_values[column_id] = kRtEps;
out_column[0] = kRtEps;
assert(out_column.size() == 1);
}
return;
}
// First thread is responsible for setting min values.
if (idx == 0) {
auto mval = in_column[idx].value;
d_min_values[column_id] = mval - (fabs(mval) + 1e-5);
}
// Last thread is responsible for setting a value that's greater than other cuts.
if (idx == out_column.size() - 1) {
const bst_float cpt = in_column.back().value;
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5);
out_column[idx] = last;
return;
}
assert(idx+1 < in_column.size());
out_column[idx] = in_column[idx+1].value;
});
timer_.Stop(__func__);
}
} // namespace common
} // namespace xgboost

141
src/common/quantile.cuh Normal file
View File

@ -0,0 +1,141 @@
#ifndef XGBOOST_COMMON_QUANTILE_CUH_
#define XGBOOST_COMMON_QUANTILE_CUH_
#include <memory>
#include "xgboost/span.h"
#include "device_helpers.cuh"
#include "quantile.h"
#include "timer.h"
namespace xgboost {
namespace common {
class HistogramCuts;
using WQSketch = WQuantileSketch<bst_float, bst_float>;
using SketchEntry = WQSketch::Entry;
/*!
* \brief A container that holds the device sketches. Sketching is performed per-column,
* but fused into single operation for performance.
*/
class SketchContainer {
public:
static constexpr float kFactor = WQSketch::kFactor;
using OffsetT = bst_row_t;
static_assert(sizeof(OffsetT) == sizeof(size_t), "Wrong type for sketch element offset.");
private:
Monitor timer_;
std::unique_ptr<dh::AllReducer> reducer_;
bst_row_t num_rows_;
bst_feature_t num_columns_;
int32_t num_bins_;
int32_t device_;
// Double buffer as neither prune nor merge can be performed inplace.
dh::caching_device_vector<SketchEntry> entries_a_;
dh::caching_device_vector<SketchEntry> entries_b_;
bool current_buffer_ {true};
// The container is just a CSC matrix.
HostDeviceVector<OffsetT> columns_ptr_;
dh::caching_device_vector<SketchEntry>& Current() {
if (current_buffer_) {
return entries_a_;
} else {
return entries_b_;
}
}
dh::caching_device_vector<SketchEntry>& Other() {
if (!current_buffer_) {
return entries_a_;
} else {
return entries_b_;
}
}
dh::caching_device_vector<SketchEntry> const& Current() const {
return const_cast<SketchContainer*>(this)->Current();
}
dh::caching_device_vector<SketchEntry> const& Other() const {
return const_cast<SketchContainer*>(this)->Other();
}
void Alternate() {
current_buffer_ = !current_buffer_;
}
// Get the span of one column.
Span<SketchEntry> Column(bst_feature_t i) {
auto data = dh::ToSpan(this->Current());
auto h_ptr = columns_ptr_.ConstHostSpan();
auto c = data.subspan(h_ptr[i], h_ptr[i+1] - h_ptr[i]);
return c;
}
public:
/* \breif GPU quantile structure, with sketch data for each columns.
*
* \param max_bin Maximum number of bins per columns
* \param num_columns Total number of columns in dataset.
* \param num_rows Total number of rows in known dataset (typically the rows in current worker).
* \param device GPU ID.
*/
SketchContainer(int32_t max_bin, bst_feature_t num_columns, bst_row_t num_rows, int32_t device) :
num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin}, device_{device} {
// Initialize Sketches for this dmatrix
this->columns_ptr_.SetDevice(device_);
this->columns_ptr_.Resize(num_columns + 1);
timer_.Init(__func__);
}
/* \brief Return GPU ID for this container. */
int32_t DeviceIdx() const { return device_; }
/* \brief Removes all the duplicated elements in quantile structure. */
size_t Unique();
/* Fix rounding error and re-establish invariance. The error is mostly generated by the
* addition inside `RMinNext` and subtraction in `RMaxPrev`. */
void FixError();
/* \brief Push a CSC structured cut matrix. */
void Push(common::Span<OffsetT const> cuts_ptr,
dh::caching_device_vector<SketchEntry>* entries);
/* \brief Prune the quantile structure.
*
* \param to The maximum size of pruned quantile. If the size of quantile structure is
* already less than `to`, then no operation is performed.
*/
void Prune(size_t to);
/* \brief Merge another set of sketch.
* \param that columns of other.
*/
void Merge(Span<OffsetT const> that_columns_ptr,
Span<SketchEntry const> that);
/* \brief Merge quantiles from other GPU workers. */
void AllReduce();
/* \brief Create the final histogram cut values. */
void MakeCuts(HistogramCuts* cuts);
Span<SketchEntry const> Data() const {
return {this->Current().data().get(), this->Current().size()};
}
Span<OffsetT const> ColumnsPtr() const { return this->columns_ptr_.ConstDeviceSpan(); }
SketchContainer(SketchContainer&&) = default;
SketchContainer& operator=(SketchContainer&&) = default;
SketchContainer(const SketchContainer&) = delete;
SketchContainer& operator=(const SketchContainer&) = delete;
};
namespace detail {
struct SketchUnique {
XGBOOST_DEVICE bool operator()(SketchEntry const& a, SketchEntry const& b) const {
return a.value - b.value == 0;
}
};
} // anonymous detail
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_QUANTILE_CUH_

View File

@ -55,6 +55,14 @@ struct WQSummary {
XGBOOST_DEVICE inline RType RMaxPrev() const {
return rmax - wmin;
}
friend std::ostream& operator<<(std::ostream& os, Entry const& e) {
os << "rmin: " << e.rmin << ", "
<< "rmax: " << e.rmax << ", "
<< "wmin: " << e.wmin << ", "
<< "value: " << e.value;
return os;
}
};
/*! \brief input data queue before entering the summary */
struct Queue {
@ -184,14 +192,14 @@ struct WQSummary {
}
}
}
/*!
* \brief set current summary to be pruned summary of src
* assume data field is already allocated to be at least maxsize
* \param src source summary
* \param maxsize size we can afford in the pruned sketch
*/
inline void SetPrune(const WQSummary &src, size_t maxsize) {
void SetPrune(const WQSummary &src, size_t maxsize) {
if (src.size <= maxsize) {
this->CopyFrom(src); return;
}
@ -454,6 +462,9 @@ struct WXQSummary : public WQSummary<DType, RType> {
*/
template<typename DType, typename RType, class TSummary>
class QuantileSketchTemplate {
public:
static float constexpr kFactor = 8.0;
public:
/*! \brief type of summary type */
using Summary = TSummary;

0
src/common/threading_utils.h Executable file → Normal file
View File

View File

@ -57,80 +57,52 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
size_t nnz = 0;
// Sketch for all batches.
iter.Reset();
common::HistogramCuts cuts;
common::DenseCuts dense_cuts(&cuts);
std::vector<common::SketchContainer> sketch_containers;
size_t batches = 0;
size_t accumulated_rows = 0;
bst_feature_t cols = 0;
int32_t device = -1;
while (iter.Next()) {
auto device = proxy->DeviceIdx();
device = proxy->DeviceIdx();
dh::safe_cuda(cudaSetDevice(device));
if (cols == 0) {
cols = num_cols();
} else {
CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
}
sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows());
sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows(), device);
auto* p_sketch = &sketch_containers.back();
if (proxy->Info().weights_.Size() != 0) {
proxy->Info().weights_.SetDevice(device);
Dispatch(proxy, [&](auto const &value) {
common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin,
proxy->Info(),
missing, device, p_sketch);
});
} else {
Dispatch(proxy, [&](auto const &value) {
common::AdapterDeviceSketch(value, batch_param_.max_bin, missing,
device, p_sketch);
});
}
proxy->Info().weights_.SetDevice(device);
Dispatch(proxy, [&](auto const &value) {
common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin,
proxy->Info(), missing, p_sketch);
});
auto batch_rows = num_rows();
accumulated_rows += batch_rows;
dh::caching_device_vector<size_t> row_counts(batch_rows + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(),
row_counts.size());
row_stride =
std::max(row_stride, Dispatch(proxy, [=](auto const& value) {
return GetRowCounts(value, row_counts_span, device, missing);
}));
nnz += thrust::reduce(thrust::cuda::par(alloc),
row_counts.begin(), row_counts.end());
row_stride = std::max(row_stride, Dispatch(proxy, [=](auto const &value) {
return GetRowCounts(value, row_counts_span,
device, missing);
}));
nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(),
row_counts.end());
batches++;
}
// Merging multiple batches for each column
std::vector<common::WQSketch::SummaryContainer> summary_array(cols);
size_t intermediate_num_cuts = std::min(
accumulated_rows, static_cast<size_t>(batch_param_.max_bin *
common::SketchContainer::kFactor));
size_t nbytes =
common::WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts);
#pragma omp parallel for num_threads(nthread) if (nthread > 0)
for (omp_ulong c = 0; c < cols; ++c) {
for (auto& sketch_batch : sketch_containers) {
common::WQSketch::SummaryContainer summary;
sketch_batch.sketches_.at(c).GetSummary(&summary);
sketch_batch.sketches_.at(c).Init(0, 1);
summary_array.at(c).Reduce(summary, nbytes);
}
common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, device);
for (auto const& sketch : sketch_containers) {
final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
final_sketch.FixError();
}
sketch_containers.clear();
sketch_containers.shrink_to_fit();
// Build the final summary.
std::vector<common::WQSketch> sketches(cols);
#pragma omp parallel for num_threads(nthread) if (nthread > 0)
for (omp_ulong c = 0; c < cols; ++c) {
sketches.at(c).Init(
accumulated_rows,
1.0 / (common::SketchContainer::kFactor * batch_param_.max_bin));
sketches.at(c).PushSummary(summary_array.at(c));
}
dense_cuts.Init(&sketches, batch_param_.max_bin, accumulated_rows);
summary_array.clear();
common::HistogramCuts cuts;
final_sketch.MakeCuts(&cuts);
this->info_.num_col_ = cols;
this->info_.num_row_ = accumulated_rows;

View File

@ -5,6 +5,7 @@
#include <thrust/device_vector.h>
#include <xgboost/base.h>
#include "../../../src/common/device_helpers.cuh"
#include "../../../src/common/quantile.h"
#include "../helpers.h"
#include "gtest/gtest.h"
@ -14,3 +15,128 @@ TEST(SumReduce, Test) {
ASSERT_NEAR(sum, 100.0f, 1e-5);
}
void TestAtomicSizeT() {
size_t constexpr kThreads = 235;
dh::device_vector<size_t> out(1, 0);
auto d_out = dh::ToSpan(out);
dh::LaunchN(0, kThreads, [=]__device__(size_t idx){
atomicAdd(&d_out[0], static_cast<size_t>(1));
});
ASSERT_EQ(out[0], kThreads);
}
TEST(AtomicAdd, SizeT) {
TestAtomicSizeT();
}
TEST(SegmentedUnique, Basic) {
std::vector<float> values{0.1f, 0.2f, 0.3f, 0.62448811531066895f, 0.62448811531066895f, 0.4f};
std::vector<size_t> segments{0, 3, 6};
thrust::device_vector<float> d_values(values);
thrust::device_vector<xgboost::bst_feature_t> d_segments{segments};
thrust::device_vector<xgboost::bst_feature_t> d_segs_out(d_segments.size());
thrust::device_vector<float> d_vals_out(d_values.size());
size_t n_uniques = dh::SegmentedUnique(
d_segments.data().get(), d_segments.data().get() + d_segments.size(),
d_values.data().get(), d_values.data().get() + d_values.size(),
d_segs_out.data().get(), d_vals_out.data().get(),
thrust::equal_to<float>{});
CHECK_EQ(n_uniques, 5);
std::vector<float> values_sol{0.1f, 0.2f, 0.3f, 0.62448811531066895f, 0.4f};
for (auto i = 0 ; i < values_sol.size(); i ++) {
ASSERT_EQ(d_vals_out[i], values_sol[i]);
}
std::vector<xgboost::bst_feature_t> segments_sol{0, 3, 5};
for (size_t i = 0; i < d_segments.size(); ++i) {
ASSERT_EQ(segments_sol[i], d_segs_out[i]);
}
d_segments[1] = 4;
d_segments[2] = 6;
n_uniques = dh::SegmentedUnique(
d_segments.data().get(), d_segments.data().get() + d_segments.size(),
d_values.data().get(), d_values.data().get() + d_values.size(),
d_segs_out.data().get(), d_vals_out.data().get(),
thrust::equal_to<float>{});
ASSERT_EQ(n_uniques, values.size());
for (auto i = 0 ; i < values.size(); i ++) {
ASSERT_EQ(d_vals_out[i], values[i]);
}
}
namespace {
using SketchEntry = xgboost::common::WQSummary<float, float>::Entry;
struct SketchUnique {
bool __device__ operator()(SketchEntry const& a, SketchEntry const& b) const {
return a.value - b.value == 0;
}
};
struct IsSorted {
bool __device__ operator()(SketchEntry const& a, SketchEntry const& b) const {
return a.value < b.value;
}
};
} // namespace
namespace xgboost {
namespace common {
void TestSegmentedUniqueRegression(std::vector<SketchEntry> values, size_t n_duplicated) {
std::vector<bst_feature_t> segments{0, static_cast<bst_feature_t>(values.size())};
thrust::device_vector<SketchEntry> d_values(values);
thrust::device_vector<bst_feature_t> d_segments(segments);
thrust::device_vector<bst_feature_t> d_segments_out(segments.size());
size_t n_uniques = dh::SegmentedUnique(
d_segments.data().get(), d_segments.data().get() + d_segments.size(), d_values.data().get(),
d_values.data().get() + d_values.size(), d_segments_out.data().get(), d_values.data().get(),
SketchUnique{});
ASSERT_EQ(n_uniques, values.size() - n_duplicated);
ASSERT_TRUE(thrust::is_sorted(thrust::device, d_values.begin(),
d_values.begin() + n_uniques, IsSorted{}));
ASSERT_EQ(segments.at(0), d_segments_out[0]);
ASSERT_EQ(segments.at(1), d_segments_out[1] + n_duplicated);
}
TEST(SegmentedUnique, Regression) {
{
std::vector<SketchEntry> values{{3149, 3150, 1, 0.62392902374267578},
{3151, 3152, 1, 0.62418866157531738},
{3152, 3153, 1, 0.62419462203979492},
{3153, 3154, 1, 0.62431186437606812},
{3154, 3155, 1, 0.6244881153106689453125},
{3155, 3156, 1, 0.6244881153106689453125},
{3155, 3156, 1, 0.6244881153106689453125},
{3155, 3156, 1, 0.6244881153106689453125},
{3157, 3158, 1, 0.62552797794342041},
{3158, 3159, 1, 0.6256556510925293},
{3159, 3160, 1, 0.62571090459823608},
{3160, 3161, 1, 0.62577134370803833}};
TestSegmentedUniqueRegression(values, 3);
}
{
std::vector<SketchEntry> values{{3149, 3150, 1, 0.62392902374267578},
{3151, 3152, 1, 0.62418866157531738},
{3152, 3153, 1, 0.62419462203979492},
{3153, 3154, 1, 0.62431186437606812},
{3154, 3155, 1, 0.6244881153106689453125},
{3157, 3158, 1, 0.62552797794342041},
{3158, 3159, 1, 0.6256556510925293},
{3159, 3160, 1, 0.62571090459823608},
{3160, 3161, 1, 0.62577134370803833}};
TestSegmentedUniqueRegression(values, 0);
}
{
std::vector<SketchEntry> values;
TestSegmentedUniqueRegression(values, 0);
}
}
} // namespace common
} // namespace xgboost

View File

@ -30,11 +30,12 @@ HistogramCuts GetHostCuts(AdapterT *adapter, int num_bins, float missing) {
builder.Build(&dmat, num_bins);
return cuts;
}
TEST(HistUtil, DeviceSketch) {
int num_rows = 5;
int num_columns = 1;
int num_bins = 4;
std::vector<float> x = {1.0, 2.0, 3.0, 4.0, 5.0};
std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 7.0f, -1.0f};
int num_rows = x.size();
auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
@ -47,26 +48,6 @@ TEST(HistUtil, DeviceSketch) {
EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues());
}
// Duplicate this function from hist_util.cu so we don't have to expose it in
// header
size_t RequiredSampleCutsTest(int max_bins, size_t num_rows) {
double eps = 1.0 / (SketchContainer::kFactor * max_bins);
size_t dummy_nlevel;
size_t num_cuts;
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
num_rows, eps, &dummy_nlevel, &num_cuts);
return std::min(num_cuts, num_rows);
}
size_t BytesRequiredForTest(size_t num_rows, size_t num_columns, size_t num_bins,
bool with_weights) {
size_t bytes_num_elements = BytesPerElement(with_weights) * num_rows * num_columns;
size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns *
sizeof(DenseCuts::WQSketch::Entry);
// divide by 2 is because the memory quota used in sorting is reused for storing cuts.
return bytes_num_elements / 2 + bytes_cuts;
}
TEST(HistUtil, DeviceSketchMemory) {
int num_columns = 100;
int num_rows = 1000;
@ -77,15 +58,15 @@ TEST(HistUtil, DeviceSketchMemory) {
dh::GlobalMemoryLogger().Clear();
ConsoleLogger::Configure({{"verbosity", "3"}});
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
ConsoleLogger::Configure({{"verbosity", "0"}});
size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, false);
size_t bytes_constant = 1000;
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant);
EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required);
size_t bytes_required = detail::RequiredMemory(
num_rows, num_columns, num_rows * num_columns, num_bins, false);
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05);
EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 0.95);
ConsoleLogger::Configure({{"verbosity", "0"}});
}
TEST(HistUtil, DeviceSketchMemoryWeights) {
TEST(HistUtil, DeviceSketchWeightsMemory) {
int num_columns = 100;
int num_rows = 1000;
int num_bins = 256;
@ -98,7 +79,8 @@ TEST(HistUtil, DeviceSketchMemoryWeights) {
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
ConsoleLogger::Configure({{"verbosity", "0"}});
size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, true);
size_t bytes_required = detail::RequiredMemory(
num_rows, num_columns, num_rows * num_columns, num_bins, true);
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05);
EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required);
}
@ -118,7 +100,7 @@ TEST(HistUtil, DeviceSketchDeterminism) {
}
}
TEST(HistUtil, DeviceSketchCategorical) {
TEST(HistUtil, DeviceSketchCategorical) {
int categorical_sizes[] = {2, 6, 8, 12};
int num_bins = 256;
int sizes[] = {25, 100, 1000};
@ -231,11 +213,10 @@ template <typename Adapter>
void ValidateBatchedCuts(Adapter adapter, int num_bins, int num_columns, int num_rows,
DMatrix* dmat) {
common::HistogramCuts batched_cuts;
SketchContainer sketch_container(num_bins, num_columns, num_rows);
SketchContainer sketch_container(num_bins, num_columns, num_rows, 0);
AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits<float>::quiet_NaN(),
0, &sketch_container);
common::DenseCuts dense_cuts(&batched_cuts);
dense_cuts.Init(&sketch_container.sketches_, num_bins, num_rows);
&sketch_container);
sketch_container.MakeCuts(&batched_cuts);
ValidateCuts(batched_cuts, dmat, num_bins);
}
@ -275,12 +256,13 @@ TEST(HistUtil, AdapterDeviceSketchMemory) {
std::numeric_limits<float>::quiet_NaN());
ConsoleLogger::Configure({{"verbosity", "0"}});
size_t bytes_constant = 1000;
size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, false);
size_t bytes_required = detail::RequiredMemory(
num_rows, num_columns, num_rows * num_columns, num_bins, false);
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant);
EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required);
EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 0.95);
}
TEST(HistUtil, AdapterSketchBatchMemory) {
TEST(HistUtil, AdapterSketchSlidingWindowMemory) {
int num_columns = 100;
int num_rows = 1000;
int num_bins = 256;
@ -291,17 +273,19 @@ TEST(HistUtil, AdapterSketchBatchMemory) {
dh::GlobalMemoryLogger().Clear();
ConsoleLogger::Configure({{"verbosity", "3"}});
common::HistogramCuts batched_cuts;
SketchContainer sketch_container(num_bins, num_columns, num_rows);
SketchContainer sketch_container(num_bins, num_columns, num_rows, 0);
AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits<float>::quiet_NaN(),
0, &sketch_container);
&sketch_container);
HistogramCuts cuts;
sketch_container.MakeCuts(&cuts);
size_t bytes_required = detail::RequiredMemory(
num_rows, num_columns, num_rows * num_columns, num_bins, false);
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05);
EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 0.95);
ConsoleLogger::Configure({{"verbosity", "0"}});
size_t bytes_constant = 1000;
size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, false);
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant);
EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required);
}
TEST(HistUtil, AdapterSketchBatchWeightedMemory) {
TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) {
int num_columns = 100;
int num_rows = 1000;
int num_bins = 256;
@ -316,12 +300,15 @@ TEST(HistUtil, AdapterSketchBatchWeightedMemory) {
dh::GlobalMemoryLogger().Clear();
ConsoleLogger::Configure({{"verbosity", "3"}});
common::HistogramCuts batched_cuts;
SketchContainer sketch_container(num_bins, num_columns, num_rows);
SketchContainer sketch_container(num_bins, num_columns, num_rows, 0);
AdapterDeviceSketchWeighted(adapter.Value(), num_bins, info,
std::numeric_limits<float>::quiet_NaN(), 0,
std::numeric_limits<float>::quiet_NaN(),
&sketch_container);
HistogramCuts cuts;
sketch_container.MakeCuts(&cuts);
ConsoleLogger::Configure({{"verbosity", "0"}});
size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, true);
size_t bytes_required = detail::RequiredMemory(
num_rows, num_columns, num_rows * num_columns, num_bins, true);
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05);
EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required);
}
@ -462,13 +449,11 @@ void TestAdapterSketchFromWeights(bool with_group) {
data::CupyAdapter adapter(m);
auto const& batch = adapter.Value();
SketchContainer sketch_container(kBins, kCols, kRows);
SketchContainer sketch_container(kBins, kCols, kRows, 0);
AdapterDeviceSketchWeighted(adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
0,
&sketch_container);
common::HistogramCuts cuts;
common::DenseCuts dense_cuts(&cuts);
dense_cuts.Init(&sketch_container.sketches_, kBins, kRows);
sketch_container.MakeCuts(&cuts);
auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols);
if (with_group) {

View File

@ -117,7 +117,7 @@ inline void TestBinDistribution(const HistogramCuts &cuts, int column_idx,
// First and last bin can have smaller
for (auto& kv : bin_weights) {
EXPECT_LE(std::abs(bin_weights[kv.first] - expected_bin_weight),
ASSERT_LE(std::abs(bin_weights[kv.first] - expected_bin_weight),
allowable_error);
}
}
@ -189,7 +189,7 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat,
// Collect data into columns
std::vector<std::vector<float>> columns(dmat->Info().num_col_);
for (auto& batch : dmat->GetBatches<SparsePage>()) {
CHECK_GT(batch.Size(), 0);
ASSERT_GT(batch.Size(), 0);
for (auto i = 0ull; i < batch.Size(); i++) {
for (auto e : batch[i]) {
columns[e.index].push_back(e.fvalue);

0
tests/cpp/common/test_partition_builder.cc Executable file → Normal file
View File

View File

@ -0,0 +1,480 @@
#include <gtest/gtest.h>
#include "../helpers.h"
#include "../../../src/common/hist_util.cuh"
#include "../../../src/common/quantile.cuh"
namespace xgboost {
namespace common {
TEST(GPUQuantile, Basic) {
constexpr size_t kRows = 1000, kCols = 100, kBins = 256;
SketchContainer sketch(kBins, kCols, kRows, 0);
dh::caching_device_vector<SketchEntry> entries;
dh::device_vector<bst_row_t> cuts_ptr(kCols+1);
thrust::fill(cuts_ptr.begin(), cuts_ptr.end(), 0);
// Push empty
sketch.Push(dh::ToSpan(cuts_ptr), &entries);
ASSERT_EQ(sketch.Data().size(), 0);
}
template <typename Fn> void RunWithSeedsAndBins(size_t rows, Fn fn) {
std::vector<int32_t> seeds(4);
SimpleLCG lcg;
SimpleRealUniformDistribution<float> dist(3, 1000);
std::generate(seeds.begin(), seeds.end(), [&](){ return dist(&lcg); });
std::vector<size_t> bins(8);
for (size_t i = 0; i < bins.size() - 1; ++i) {
bins[i] = i * 35 + 2;
}
bins.back() = rows + 80; // provide a bin number greater than rows.
std::vector<MetaInfo> infos(2);
auto& h_weights = infos.front().weights_.HostVector();
h_weights.resize(rows);
std::generate(h_weights.begin(), h_weights.end(), [&]() { return dist(&lcg); });
for (auto seed : seeds) {
for (auto n_bin : bins) {
for (auto const& info : infos) {
fn(seed, n_bin, info);
}
}
}
}
void TestSketchUnique(float sparsity) {
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [kRows, kCols, sparsity](int32_t seed, size_t n_bins, MetaInfo const& info) {
SketchContainer sketch(n_bins, kCols, kRows, 0);
HostDeviceVector<float> storage;
std::string interface_str = RandomDataGenerator{kRows, kCols, sparsity}
.Seed(seed)
.Device(0)
.GenerateArrayInterface(&storage);
data::CupyAdapter adapter(interface_str);
AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(), &sketch);
auto n_cuts = detail::RequiredSampleCutsPerColumn(n_bins, kRows);
dh::caching_device_vector<size_t> column_sizes_scan;
HostDeviceVector<size_t> cut_sizes_scan;
auto batch = adapter.Value();
data::IsValidFunctor is_valid(std::numeric_limits<float>::quiet_NaN());
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
thrust::make_counting_iterator(0llu),
[=] __device__(size_t idx) { return batch.GetElement(idx); });
auto end = kCols * kRows;
detail::GetColumnSizesScan(0, kCols, n_cuts, batch_iter, is_valid, 0, end,
&cut_sizes_scan, &column_sizes_scan);
auto const& cut_sizes = cut_sizes_scan.HostVector();
if (sparsity == 0) {
ASSERT_EQ(sketch.Data().size(), n_cuts * kCols);
} else {
ASSERT_EQ(sketch.Data().size(), cut_sizes.back());
}
sketch.Unique();
ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch.Data().data(),
sketch.Data().data() + sketch.Data().size(),
detail::SketchUnique{}));
});
}
TEST(GPUQuantile, Unique) {
TestSketchUnique(0);
TestSketchUnique(0.5);
}
// if with_error is true, the test tolerates floating point error
void TestQuantileElemRank(int32_t device, Span<SketchEntry const> in,
Span<bst_row_t const> d_columns_ptr, bool with_error = false) {
dh::LaunchN(device, in.size(), [=]XGBOOST_DEVICE(size_t idx) {
auto column_id = dh::SegmentId(d_columns_ptr, idx);
auto in_column = in.subspan(d_columns_ptr[column_id],
d_columns_ptr[column_id + 1] -
d_columns_ptr[column_id]);
auto constexpr kEps = 1e-6f;
idx -= d_columns_ptr[column_id];
float prev_rmin = idx == 0 ? 0.0f : in_column[idx-1].rmin;
float prev_rmax = idx == 0 ? 0.0f : in_column[idx-1].rmax;
float rmin_next = in_column[idx].RMinNext();
if (with_error) {
SPAN_CHECK(in_column[idx].rmin + in_column[idx].rmin * kEps >= prev_rmin);
SPAN_CHECK(in_column[idx].rmax + in_column[idx].rmin * kEps >= prev_rmax);
SPAN_CHECK(in_column[idx].rmax + in_column[idx].rmin * kEps >= rmin_next);
} else {
SPAN_CHECK(in_column[idx].rmin >= prev_rmin);
SPAN_CHECK(in_column[idx].rmax >= prev_rmax);
SPAN_CHECK(in_column[idx].rmax >= rmin_next);
}
});
// Force sync to terminate current test instead of a later one.
dh::DebugSyncDevice(__FILE__, __LINE__);
}
TEST(GPUQuantile, Prune) {
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
SketchContainer sketch(n_bins, kCols, kRows, 0);
HostDeviceVector<float> storage;
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
.Device(0)
.Seed(seed)
.GenerateArrayInterface(&storage);
data::CupyAdapter adapter(interface_str);
AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(), &sketch);
auto n_cuts = detail::RequiredSampleCutsPerColumn(n_bins, kRows);
ASSERT_EQ(sketch.Data().size(), n_cuts * kCols);
sketch.Prune(n_bins);
if (n_bins <= kRows) {
ASSERT_EQ(sketch.Data().size(), n_bins * kCols);
} else {
// LE because kRows * kCols is pushed into sketch, after removing duplicated entries
// we might not have that much inputs for prune.
ASSERT_LE(sketch.Data().size(), kRows * kCols);
}
// This is not necessarily true for all inputs without calling unique after
// prune.
ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch.Data().data(),
sketch.Data().data() + sketch.Data().size(),
detail::SketchUnique{}));
TestQuantileElemRank(0, sketch.Data(), sketch.ColumnsPtr());
});
}
TEST(GPUQuantile, MergeEmpty) {
constexpr size_t kRows = 1000, kCols = 100;
size_t n_bins = 10;
SketchContainer sketch_0(n_bins, kCols, kRows, 0);
HostDeviceVector<float> storage_0;
std::string interface_str_0 =
RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateArrayInterface(
&storage_0);
data::CupyAdapter adapter_0(interface_str_0);
AdapterDeviceSketch(adapter_0.Value(), n_bins,
std::numeric_limits<float>::quiet_NaN(), &sketch_0);
std::vector<SketchEntry> entries_before(sketch_0.Data().size());
dh::CopyDeviceSpanToVector(&entries_before, sketch_0.Data());
std::vector<bst_row_t> ptrs_before(sketch_0.ColumnsPtr().size());
dh::CopyDeviceSpanToVector(&ptrs_before, sketch_0.ColumnsPtr());
thrust::device_vector<size_t> columns_ptr(kCols + 1);
// Merge an empty sketch
sketch_0.Merge(dh::ToSpan(columns_ptr), Span<SketchEntry>{});
std::vector<SketchEntry> entries_after(sketch_0.Data().size());
dh::CopyDeviceSpanToVector(&entries_after, sketch_0.Data());
std::vector<bst_row_t> ptrs_after(sketch_0.ColumnsPtr().size());
dh::CopyDeviceSpanToVector(&ptrs_after, sketch_0.ColumnsPtr());
CHECK_EQ(entries_before.size(), entries_after.size());
CHECK_EQ(ptrs_before.size(), ptrs_after.size());
for (size_t i = 0; i < entries_before.size(); ++i) {
CHECK_EQ(entries_before[i].value, entries_after[i].value);
CHECK_EQ(entries_before[i].rmin, entries_after[i].rmin);
CHECK_EQ(entries_before[i].rmax, entries_after[i].rmax);
CHECK_EQ(entries_before[i].wmin, entries_after[i].wmin);
}
for (size_t i = 0; i < ptrs_before.size(); ++i) {
CHECK_EQ(ptrs_before[i], ptrs_after[i]);
}
}
TEST(GPUQuantile, MergeBasic) {
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
SketchContainer sketch_0(n_bins, kCols, kRows, 0);
HostDeviceVector<float> storage_0;
std::string interface_str_0 = RandomDataGenerator{kRows, kCols, 0}
.Device(0)
.Seed(seed)
.GenerateArrayInterface(&storage_0);
data::CupyAdapter adapter_0(interface_str_0);
AdapterDeviceSketchWeighted(adapter_0.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(), &sketch_0);
SketchContainer sketch_1(n_bins, kCols, kRows * kRows, 0);
HostDeviceVector<float> storage_1;
std::string interface_str_1 = RandomDataGenerator{kRows, kCols, 0}
.Device(0)
.Seed(seed)
.GenerateArrayInterface(&storage_1);
data::CupyAdapter adapter_1(interface_str_1);
AdapterDeviceSketchWeighted(adapter_1.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(), &sketch_1);
size_t size_before_merge = sketch_0.Data().size();
sketch_0.Merge(sketch_1.ColumnsPtr(), sketch_1.Data());
if (info.weights_.Size() != 0) {
TestQuantileElemRank(0, sketch_0.Data(), sketch_0.ColumnsPtr(), true);
sketch_0.FixError();
TestQuantileElemRank(0, sketch_0.Data(), sketch_0.ColumnsPtr(), false);
} else {
TestQuantileElemRank(0, sketch_0.Data(), sketch_0.ColumnsPtr());
}
auto columns_ptr = sketch_0.ColumnsPtr();
std::vector<bst_row_t> h_columns_ptr(columns_ptr.size());
dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge);
sketch_0.Unique();
ASSERT_TRUE(
thrust::is_sorted(thrust::device, sketch_0.Data().data(),
sketch_0.Data().data() + sketch_0.Data().size(),
detail::SketchUnique{}));
});
}
void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) {
MetaInfo info;
int32_t seed = 0;
SketchContainer sketch_0(n_bins, cols, rows, 0);
HostDeviceVector<float> storage_0;
std::string interface_str_0 = RandomDataGenerator{rows, cols, 0}
.Device(0)
.Seed(seed)
.GenerateArrayInterface(&storage_0);
data::CupyAdapter adapter_0(interface_str_0);
AdapterDeviceSketchWeighted(adapter_0.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(),
&sketch_0);
size_t f_rows = rows * frac;
SketchContainer sketch_1(n_bins, cols, f_rows, 0);
HostDeviceVector<float> storage_1;
std::string interface_str_1 = RandomDataGenerator{f_rows, cols, 0}
.Device(0)
.Seed(seed)
.GenerateArrayInterface(&storage_1);
auto data_1 = storage_1.DeviceSpan();
auto tuple_it = thrust::make_tuple(
thrust::make_counting_iterator<size_t>(0ul), data_1.data());
using Tuple = thrust::tuple<size_t, float>;
auto it = thrust::make_zip_iterator(tuple_it);
thrust::transform(thrust::device, it, it + data_1.size(), data_1.data(),
[=] __device__(Tuple const &tuple) {
auto i = thrust::get<0>(tuple);
if (thrust::get<0>(tuple) % 2 == 0) {
return 0.0f;
} else {
return thrust::get<1>(tuple);
}
});
data::CupyAdapter adapter_1(interface_str_1);
AdapterDeviceSketchWeighted(adapter_1.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(),
&sketch_1);
size_t size_before_merge = sketch_0.Data().size();
sketch_0.Merge(sketch_1.ColumnsPtr(), sketch_1.Data());
TestQuantileElemRank(0, sketch_0.Data(), sketch_0.ColumnsPtr());
auto columns_ptr = sketch_0.ColumnsPtr();
std::vector<bst_row_t> h_columns_ptr(columns_ptr.size());
dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge);
sketch_0.Unique();
ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch_0.Data().data(),
sketch_0.Data().data() + sketch_0.Data().size(),
detail::SketchUnique{}));
}
TEST(GPUQuantile, MergeDuplicated) {
size_t n_bins = 256;
constexpr size_t kRows = 1000, kCols = 100;
for (float frac = 0.5; frac < 2.5; frac += 0.5) {
TestMergeDuplicated(n_bins, kRows, kCols, frac);
}
}
void InitRabitContext(std::string msg) {
auto n_gpus = AllVisibleGPUs();
auto port = std::getenv("DMLC_TRACKER_PORT");
std::string port_str;
if (port) {
port_str = port;
} else {
LOG(WARNING) << msg << " as `DMLC_TRACKER_PORT` is not set up.";
return;
}
std::vector<std::string> envs{
"DMLC_TRACKER_PORT=" + port_str,
"DMLC_TRACKER_URI=127.0.0.1",
"DMLC_NUM_WORKER=" + std::to_string(n_gpus)};
char* c_envs[] {&(envs[0][0]), &(envs[1][0]), &(envs[2][0])};
rabit::Init(3, c_envs);
}
TEST(GPUQuantile, AllReduceBasic) {
// This test is supposed to run by a python test that setups the environment.
std::string msg {"Skipping AllReduce test"};
#if defined(__linux__) && defined(XGBOOST_USE_NCCL)
InitRabitContext(msg);
auto n_gpus = AllVisibleGPUs();
auto world = rabit::GetWorldSize();
if (world != 1) {
ASSERT_EQ(world, n_gpus);
} else {
return;
}
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
// Set up single node version;
SketchContainer sketch_on_single_node(n_bins, kCols, kRows, 0);
size_t intermediate_num_cuts =
std::min(kRows * world, static_cast<size_t>(n_bins * WQSketch::kFactor));
std::vector<SketchContainer> containers;
for (auto rank = 0; rank < world; ++rank) {
HostDeviceVector<float> storage;
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
.Device(0)
.Seed(rank + seed)
.GenerateArrayInterface(&storage);
data::CupyAdapter adapter(interface_str);
containers.emplace_back(n_bins, kCols, kRows, 0);
AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(),
&containers.back());
}
for (auto& sketch : containers) {
sketch.Prune(intermediate_num_cuts);
sketch_on_single_node.Merge(sketch.ColumnsPtr(), sketch.Data());
sketch_on_single_node.FixError();
}
sketch_on_single_node.Unique();
TestQuantileElemRank(0, sketch_on_single_node.Data(),
sketch_on_single_node.ColumnsPtr());
// Set up distributed version. We rely on using rank as seed to generate
// the exact same copy of data.
auto rank = rabit::GetRank();
SketchContainer sketch_distributed(n_bins, kCols, kRows, 0);
HostDeviceVector<float> storage;
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
.Device(0)
.Seed(rank + seed)
.GenerateArrayInterface(&storage);
data::CupyAdapter adapter(interface_str);
AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(),
&sketch_distributed);
sketch_distributed.AllReduce();
sketch_distributed.Unique();
ASSERT_EQ(sketch_distributed.ColumnsPtr().size(),
sketch_on_single_node.ColumnsPtr().size());
ASSERT_EQ(sketch_distributed.Data().size(),
sketch_on_single_node.Data().size());
TestQuantileElemRank(0, sketch_distributed.Data(),
sketch_distributed.ColumnsPtr());
std::vector<SketchEntry> single_node_data(
sketch_on_single_node.Data().size());
dh::CopyDeviceSpanToVector(&single_node_data, sketch_on_single_node.Data());
std::vector<SketchEntry> distributed_data(sketch_distributed.Data().size());
dh::CopyDeviceSpanToVector(&distributed_data, sketch_distributed.Data());
float Eps = 2e-4 * world;
for (size_t i = 0; i < single_node_data.size(); ++i) {
ASSERT_NEAR(single_node_data[i].value, distributed_data[i].value, Eps);
ASSERT_NEAR(single_node_data[i].rmax, distributed_data[i].rmax, Eps);
ASSERT_NEAR(single_node_data[i].rmin, distributed_data[i].rmin, Eps);
ASSERT_NEAR(single_node_data[i].wmin, distributed_data[i].wmin, Eps);
}
});
rabit::Finalize();
#else
LOG(WARNING) << msg;
return;
#endif // !defined(__linux__) && defined(XGBOOST_USE_NCCL)
}
TEST(GPUQuantile, SameOnAllWorkers) {
std::string msg {"Skipping SameOnAllWorkers test"};
#if defined(__linux__) && defined(XGBOOST_USE_NCCL)
InitRabitContext(msg);
auto world = rabit::GetWorldSize();
auto n_gpus = AllVisibleGPUs();
if (world != 1) {
ASSERT_EQ(world, n_gpus);
} else {
return;
}
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins,
MetaInfo const &info) {
auto rank = rabit::GetRank();
SketchContainer sketch_distributed(n_bins, kCols, kRows, 0);
HostDeviceVector<float> storage;
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
.Device(0)
.Seed(rank + seed)
.GenerateArrayInterface(&storage);
data::CupyAdapter adapter(interface_str);
AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(),
&sketch_distributed);
sketch_distributed.AllReduce();
sketch_distributed.Unique();
TestQuantileElemRank(0, sketch_distributed.Data(), sketch_distributed.ColumnsPtr());
// Test for all workers having the same sketch.
size_t n_data = sketch_distributed.Data().size();
rabit::Allreduce<rabit::op::Max>(&n_data, 1);
ASSERT_EQ(n_data, sketch_distributed.Data().size());
size_t size_as_float =
sketch_distributed.Data().size_bytes() / sizeof(float);
auto local_data = Span<float const>{
reinterpret_cast<float const *>(sketch_distributed.Data().data()),
size_as_float};
dh::caching_device_vector<float> all_workers(size_as_float * world);
thrust::fill(all_workers.begin(), all_workers.end(), 0);
thrust::copy(thrust::device, local_data.data(),
local_data.data() + local_data.size(),
all_workers.begin() + local_data.size() * rank);
dh::AllReducer reducer;
reducer.Init(0);
reducer.AllReduceSum(all_workers.data().get(), all_workers.data().get(),
all_workers.size());
reducer.Synchronize();
auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float);
std::vector<float> h_base_line(base_line.size());
dh::CopyDeviceSpanToVector(&h_base_line, base_line);
size_t offset = 0;
for (size_t i = 0; i < world; ++i) {
auto comp = dh::ToSpan(all_workers).subspan(offset, size_as_float);
std::vector<float> h_comp(comp.size());
dh::CopyDeviceSpanToVector(&h_comp, comp);
ASSERT_EQ(comp.size(), base_line.size());
for (size_t j = 0; j < h_comp.size(); ++j) {
ASSERT_NEAR(h_base_line[j], h_comp[j], kRtEps);
}
offset += size_as_float;
}
});
#else
LOG(WARNING) << msg;
return;
#endif // !defined(__linux__) && defined(XGBOOST_USE_NCCL)
}
} // namespace common
} // namespace xgboost

View File

@ -422,11 +422,11 @@ TEST(Span, Subspan) {
ASSERT_EQ(s4.size(), s1.size() - 2);
EXPECT_DEATH(s1.subspan(-1, 0), "\\[xgboost\\] Condition .* failed.\n");
EXPECT_DEATH(s1.subspan(16, 0), "\\[xgboost\\] Condition .* failed.\n");
EXPECT_DEATH(s1.subspan(17, 0), "\\[xgboost\\] Condition .* failed.\n");
auto constexpr kOne = static_cast<Span<int, 4>::index_type>(-1);
EXPECT_DEATH(s1.subspan<kOne>(), "\\[xgboost\\] Condition .* failed.\n");
EXPECT_DEATH(s1.subspan<16>(), "\\[xgboost\\] Condition .* failed.\n");
EXPECT_DEATH(s1.subspan<17>(), "\\[xgboost\\] Condition .* failed.\n");
}
TEST(Span, Compare) {

0
tests/cpp/common/test_threading_utils.cc Executable file → Normal file
View File

View File

@ -168,7 +168,9 @@ class SimpleRealUniformDistribution {
ResultT operator()(GeneratorT* rng) const {
ResultT tmp = GenerateCanonical<std::numeric_limits<ResultT>::digits,
GeneratorT>(rng);
return (tmp * (upper_ - lower_)) + lower_;
auto ret = (tmp * (upper_ - lower_)) + lower_;
// Correct floating point error.
return std::max(ret, lower_);
}
};

View File

@ -1,4 +1,5 @@
[pytest]
markers =
mgpu: Mark a test that requires multiple GPUs to run.
ci: Mark a test that runs only on CI.
ci: Mark a test that runs only on CI.
gtest: Mark a test that requires C++ Google Test executable.

View File

@ -1,8 +1,12 @@
import sys
import os
import pytest
import numpy as np
import unittest
import xgboost
import subprocess
from hypothesis import given, strategies, settings, note
from test_gpu_updaters import parameter_strategy
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@ -12,11 +16,13 @@ from test_with_dask import run_empty_dmatrix # noqa
from test_with_dask import generate_array # noqa
import testing as tm # noqa
try:
import dask.dataframe as dd
from xgboost import dask as dxgb
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
from dask import array as da
import cudf
except ImportError:
pass
@ -42,7 +48,8 @@ class TestDistributedGPU(unittest.TestCase):
y = y.map_partitions(cudf.from_pandas)
dtrain = dxgb.DaskDMatrix(client, X, y)
out = dxgb.train(client, {'tree_method': 'gpu_hist'},
out = dxgb.train(client, {'tree_method': 'gpu_hist',
'debug_synchronize': True},
dtrain=dtrain,
evals=[(dtrain, 'X')],
num_boost_round=4)
@ -61,7 +68,8 @@ class TestDistributedGPU(unittest.TestCase):
xgboost.DMatrix(X.compute()))
cp.testing.assert_allclose(single_node, predictions)
np.testing.assert_allclose(single_node, series_predictions.to_array())
np.testing.assert_allclose(single_node,
series_predictions.to_array())
predt = dxgb.predict(client, out, X)
assert isinstance(predt, dd.Series)
@ -77,6 +85,41 @@ class TestDistributedGPU(unittest.TestCase):
cp.testing.assert_allclose(
predt.values.compute(), single_node)
@given(parameter_strategy, strategies.integers(1, 20),
tm.dataset_strategy)
@settings(deadline=None)
@pytest.mark.mgpu
def test_gpu_hist(self, params, num_rounds, dataset):
with LocalCUDACluster(n_workers=2) as cluster:
with Client(cluster) as client:
params['tree_method'] = 'gpu_hist'
params = dataset.set_params(params)
# multi class doesn't handle empty dataset well (empty
# means at least 1 worker has data).
if params['objective'] == "multi:softmax":
return
# It doesn't make sense to distribute a completely
# empty dataset.
if dataset.X.shape[0] == 0:
return
chunk = 128
X = da.from_array(dataset.X,
chunks=(chunk, dataset.X.shape[1]))
y = da.from_array(dataset.y, chunks=(chunk, ))
if dataset.w is not None:
w = da.from_array(dataset.w, chunks=(chunk, ))
else:
w = None
m = dxgb.DaskDMatrix(
client, data=X, label=y, weight=w)
history = dxgb.train(client, params=params, dtrain=m,
num_boost_round=num_rounds,
evals=[(m, 'train')])['history']
note(history)
assert tm.non_increasing(history['train'][dataset.metric])
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.mgpu
def test_dask_array(self):
@ -89,7 +132,8 @@ class TestDistributedGPU(unittest.TestCase):
X = X.map_blocks(cp.asarray)
y = y.map_blocks(cp.asarray)
dtrain = dxgb.DaskDMatrix(client, X, y)
out = dxgb.train(client, {'tree_method': 'gpu_hist'},
out = dxgb.train(client, {'tree_method': 'gpu_hist',
'debug_synchronize': True},
dtrain=dtrain,
evals=[(dtrain, 'X')],
num_boost_round=2)
@ -107,12 +151,62 @@ class TestDistributedGPU(unittest.TestCase):
single_node,
inplace_predictions)
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.mgpu
def test_empty_dmatrix(self):
with LocalCUDACluster() as cluster:
with Client(cluster) as client:
parameters = {'tree_method': 'gpu_hist'}
parameters = {'tree_method': 'gpu_hist',
'debug_synchronize': True}
run_empty_dmatrix(client, parameters)
def run_quantile(self, name):
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows")
exe = None
for possible_path in {'./testxgboost', './build/testxgboost',
'../build/testxgboost', '../gpu-build/testxgboost'}:
if os.path.exists(possible_path):
exe = possible_path
assert exe, 'No testxgboost executable found.'
test = "--gtest_filter=GPUQuantile." + name
def runit(worker_addr, rabit_args):
port = None
# setup environment for running the c++ part.
for arg in rabit_args:
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
port = arg.decode('utf-8')
port = port.split('=')
env = os.environ.copy()
env[port[0]] = port[1]
return subprocess.run([exe, test], env=env, stdout=subprocess.PIPE)
with LocalCUDACluster() as cluster:
with Client(cluster) as client:
workers = list(dxgb._get_client_workers(client).keys())
rabit_args = dxgb._get_rabit_args(workers, client)
futures = client.map(runit,
workers,
pure=False,
workers=workers,
rabit_args=rabit_args)
results = client.gather(futures)
for ret in results:
msg = ret.stdout.decode('utf-8')
assert msg.find('1 test from GPUQuantile') != -1, msg
assert ret.returncode == 0, msg
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.mgpu
@pytest.mark.gtest
def test_quantile_basic(self):
self.run_quantile('AllReduceBasic')
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.mgpu
@pytest.mark.gtest
def test_quantile_same_on_all_workers(self):
self.run_quantile('SameOnAllWorkers')

View File

@ -1,10 +1,12 @@
# coding: utf-8
import os
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
from xgboost.compat import DASK_INSTALLED
from hypothesis import strategies
from hypothesis.extra.numpy import arrays
from joblib import Memory
from sklearn import datasets
import tempfile
import xgboost as xgb
import numpy as np
@ -123,10 +125,15 @@ class TestDataset:
return xgb.DeviceQuantileDMatrix(X, y, w)
def get_external_dmat(self):
np.savetxt('tmptmp_1234.csv', np.hstack((self.y.reshape(len(self.y), 1), self.X)),
delimiter=',')
return xgb.DMatrix('tmptmp_1234.csv?format=csv&label_column=0#tmptmp_',
weight=self.w)
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, 'tmptmp_1234.csv')
np.savetxt(path,
np.hstack((self.y.reshape(len(self.y), 1), self.X)),
delimiter=',')
uri = path + '?format=csv&label_column=0#tmptmp_'
# The uri looks like:
# 'tmptmp_1234.csv?format=csv&label_column=0#tmptmp_'
return xgb.DMatrix(uri, weight=self.w)
def __repr__(self):
return self.name
@ -181,6 +188,7 @@ def _dataset_and_weight(draw):
data.w = draw(arrays(np.float64, (len(data.y)), elements=strategies.floats(0.1, 2.0)))
return data
# A strategy for drawing from a set of example datasets
# May add random weights to the dataset
dataset_strategy = _dataset_and_weight()