Implement GK sketching on GPU. (#5846)

* Implement GK sketching on GPU.
* Strong tests on quantile building.
* Handle sparse datasets by binary-searching the column index.
* Hypothesis test on dask.
Jiaming Yuan 2020-07-07 12:16:21 +08:00 committed by GitHub
parent ac3f0e78dc
commit 048d969be4
25 changed files with 2045 additions and 405 deletions

Jenkinsfile

@ -313,6 +313,7 @@ def TestPythonGPU(args) {
nodeReq = (args.multi_gpu) ? 'linux && mgpu' : 'linux && gpu' nodeReq = (args.multi_gpu) ? 'linux && mgpu' : 'linux && gpu'
node(nodeReq) { node(nodeReq) {
unstash name: 'xgboost_whl_cuda10' unstash name: 'xgboost_whl_cuda10'
unstash name: 'xgboost_cpp_tests'
unstash name: 'srcs' unstash name: 'srcs'
echo "Test Python GPU: CUDA ${args.cuda_version}" echo "Test Python GPU: CUDA ${args.cuda_version}"
def container_type = "gpu" def container_type = "gpu"


@ -573,8 +573,8 @@ class Span {
XGBOOST_DEVICE auto subspan() const -> // NOLINT XGBOOST_DEVICE auto subspan() const -> // NOLINT
Span<element_type, Span<element_type,
detail::ExtentValue<Extent, Offset, Count>::value> { detail::ExtentValue<Extent, Offset, Count>::value> {
SPAN_CHECK(Offset < size() || size() == 0); SPAN_CHECK((Count == dynamic_extent) ?
SPAN_CHECK(Count == dynamic_extent || (Offset + Count <= size())); (Offset <= size()) : (Offset + Count <= size()));
return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count}; return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count};
} }
@ -582,9 +582,8 @@ class Span {
XGBOOST_DEVICE Span<element_type, dynamic_extent> subspan( // NOLINT XGBOOST_DEVICE Span<element_type, dynamic_extent> subspan( // NOLINT
index_type _offset, index_type _offset,
index_type _count = dynamic_extent) const { index_type _count = dynamic_extent) const {
SPAN_CHECK(_offset < size() || size() == 0); SPAN_CHECK((_count == dynamic_extent) ?
SPAN_CHECK((_count == dynamic_extent) || (_offset + _count <= size())); (_offset <= size()) : (_offset + _count <= size()));
return {data() + _offset, _count == return {data() + _offset, _count ==
dynamic_extent ? size() - _offset : _count}; dynamic_extent ? size() - _offset : _count};
} }


@ -78,6 +78,33 @@ void AllReducer::Init(int _device_ordinal) {
#endif // XGBOOST_USE_NCCL #endif // XGBOOST_USE_NCCL
} }
void AllReducer::AllGather(void const *data, size_t length_bytes,
std::vector<size_t> *segments,
dh::caching_device_vector<char> *recvbuf) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
size_t world = rabit::GetWorldSize();
segments->clear();
segments->resize(world, 0);
segments->at(rabit::GetRank()) = length_bytes;
rabit::Allreduce<rabit::op::Max>(segments->data(), segments->size());
auto total_bytes = std::accumulate(segments->cbegin(), segments->cend(), 0);
recvbuf->resize(total_bytes);
size_t offset = 0;
safe_nccl(ncclGroupStart());
for (int32_t i = 0; i < world; ++i) {
size_t as_bytes = segments->at(i);
safe_nccl(
ncclBroadcast(data, recvbuf->data().get() + offset,
as_bytes, ncclChar, i, comm_, stream_));
offset += as_bytes;
}
safe_nccl(ncclGroupEnd());
#endif // XGBOOST_USE_NCCL
}
AllReducer::~AllReducer() { AllReducer::~AllReducer() {
#ifdef XGBOOST_USE_NCCL #ifdef XGBOOST_USE_NCCL
if (initialised_) { if (initialised_) {


@ -5,10 +5,16 @@
#include <thrust/device_ptr.h> #include <thrust/device_ptr.h>
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <thrust/device_malloc_allocator.h> #include <thrust/device_malloc_allocator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/system/cuda/error.h> #include <thrust/system/cuda/error.h>
#include <thrust/system_error.h> #include <thrust/system_error.h>
#include <thrust/execution_policy.h>
#include <thrust/transform_scan.h>
#include <thrust/logical.h> #include <thrust/logical.h>
#include <thrust/gather.h> #include <thrust/gather.h>
#include <thrust/unique.h>
#include <thrust/binary_search.h> #include <thrust/binary_search.h>
#include <rabit/rabit.h> #include <rabit/rabit.h>
@ -53,6 +59,36 @@ __device__ __forceinline__ double atomicAdd(double* address, double val) { // N
} }
#endif #endif
namespace dh {
namespace detail {
template <size_t size>
struct AtomicDispatcher;
template <>
struct AtomicDispatcher<sizeof(uint32_t)> {
using Type = unsigned int; // NOLINT
static_assert(sizeof(Type) == sizeof(uint32_t), "Unsigned should be of size 32 bits.");
};
template <>
struct AtomicDispatcher<sizeof(uint64_t)> {
using Type = unsigned long long; // NOLINT
static_assert(sizeof(Type) == sizeof(uint64_t), "Unsigned long long should be of size 64 bits.");
};
} // namespace detail
} // namespace dh
// atomicAdd is not defined for size_t.
template <typename T = size_t,
std::enable_if_t<std::is_same<size_t, T>::value &&
!std::is_same<size_t, unsigned long long>::value> * = // NOLINT
nullptr>
T __device__ __forceinline__ atomicAdd(T *addr, T v) { // NOLINT
using Type = typename dh::detail::AtomicDispatcher<sizeof(T)>::Type;
Type ret = ::atomicAdd(reinterpret_cast<Type *>(addr), static_cast<Type>(v));
return static_cast<T>(ret);
}
namespace dh { namespace dh {
#define HOST_DEV_INLINE XGBOOST_DEVICE __forceinline__ #define HOST_DEV_INLINE XGBOOST_DEVICE __forceinline__
@ -291,10 +327,12 @@ public:
safe_cuda(cudaGetDevice(&current_device)); safe_cuda(cudaGetDevice(&current_device));
stats_.RegisterDeallocation(ptr, n, current_device); stats_.RegisterDeallocation(ptr, n, current_device);
} }
size_t PeakMemory() size_t PeakMemory() const {
{
return stats_.peak_allocated_bytes; return stats_.peak_allocated_bytes;
} }
size_t CurrentlyAllocatedBytes() const {
return stats_.currently_allocated_bytes;
}
void Clear() void Clear()
{ {
stats_ = DeviceStats(); stats_ = DeviceStats();
@ -529,7 +567,6 @@ class AllReducer {
bool initialised_ {false}; bool initialised_ {false};
size_t allreduce_bytes_ {0}; // Keep statistics of the number of bytes communicated size_t allreduce_bytes_ {0}; // Keep statistics of the number of bytes communicated
size_t allreduce_calls_ {0}; // Keep statistics of the number of reduce calls size_t allreduce_calls_ {0}; // Keep statistics of the number of reduce calls
std::vector<size_t> host_data_; // Used for all reduce on host
#ifdef XGBOOST_USE_NCCL #ifdef XGBOOST_USE_NCCL
ncclComm_t comm_; ncclComm_t comm_;
cudaStream_t stream_; cudaStream_t stream_;
@ -569,6 +606,27 @@ class AllReducer {
#endif #endif
} }
/**
* \brief Allgather implemented as grouped calls to Broadcast. This way we can accept
* different sizes of data on different workers.
* \param length_bytes Size of input data in bytes.
* \param segments Size of data on each worker.
* \param recvbuf Buffer storing the result of data from all workers.
*/
void AllGather(void const* data, size_t length_bytes,
std::vector<size_t>* segments, dh::caching_device_vector<char>* recvbuf);
void AllGather(uint32_t const* data, size_t length,
dh::caching_device_vector<uint32_t>* recvbuf) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
size_t world = rabit::GetWorldSize();
recvbuf->resize(length * world);
safe_nccl(ncclAllGather(data, recvbuf->data().get(), length, ncclUint32,
comm_, stream_));
#endif // XGBOOST_USE_NCCL
}
/** /**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing * \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms. * streams or comms.
@ -607,6 +665,40 @@ class AllReducer {
#endif #endif
} }
void AllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint32, ncclSum, comm_, stream_));
#endif
}
void AllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_));
#endif
}
// Specialization for size_t, which is implementation defined so it might or might not
// be one of uint64_t/uint32_t/unsigned long long/unsigned long.
template <typename T = size_t,
std::enable_if_t<std::is_same<size_t, T>::value &&
!std::is_same<size_t, unsigned long long>::value> // NOLINT
* = nullptr>
void AllReduceSum(const T *sendbuff, T *recvbuff, int count) { // NOLINT
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
static_assert(sizeof(unsigned long long) == sizeof(uint64_t), ""); // NOLINT
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_));
#endif
}
/** /**
* \fn void Synchronize() * \fn void Synchronize()
* *
@ -886,9 +978,86 @@ DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
// Thrust version of this function causes error on Windows // Thrust version of this function causes error on Windows
template <typename ReturnT, typename IterT, typename FuncT> template <typename ReturnT, typename IterT, typename FuncT>
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator( XGBOOST_DEVICE thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
IterT iter, FuncT func) { IterT iter, FuncT func) {
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func); return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
} }
template <typename It>
size_t XGBOOST_DEVICE SegmentId(It first, It last, size_t idx) {
size_t segment_id = thrust::upper_bound(thrust::seq, first, last, idx) -
1 - first;
return segment_id;
}
template <typename T>
size_t XGBOOST_DEVICE SegmentId(xgboost::common::Span<T> segments_ptr, size_t idx) {
return SegmentId(segments_ptr.cbegin(), segments_ptr.cend(), idx);
}
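
A minimal CPU sketch of the lookup that SegmentId performs above; this is the binary search the commit uses to map a flat element index back to its owning column when the data is sparse (ExtractCutsSparse and MergePath below rely on it). Plain standard C++; the name SegmentIdCpu and the example pointer array are assumptions for illustration only.

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU analogue of dh::SegmentId: given an exclusive-scan style
// segment pointer, find the segment that owns flat index `idx`.
std::size_t SegmentIdCpu(std::vector<std::size_t> const& ptr, std::size_t idx) {
  // First pointer strictly greater than idx, then step back one segment.
  return std::upper_bound(ptr.begin(), ptr.end(), idx) - ptr.begin() - 1;
}

int main() {
  // Three columns holding 3, 0 and 4 entries respectively.
  std::vector<std::size_t> ptr{0, 3, 3, 7};
  assert(SegmentIdCpu(ptr, 0) == 0);  // first entry lives in column 0
  assert(SegmentIdCpu(ptr, 3) == 2);  // column 1 is empty, entry 3 is column 2's first
  assert(SegmentIdCpu(ptr, 6) == 2);  // last entry also belongs to column 2
  return 0;
}
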
namespace detail {
template <typename Key, typename KeyOutIt>
struct SegmentedUniqueReduceOp {
KeyOutIt key_out;
__device__ Key const& operator()(Key const& key) const {
auto constexpr kOne = static_cast<std::remove_reference_t<decltype(*(key_out + key.first))>>(1);
atomicAdd(&(*(key_out + key.first)), kOne);
return key;
}
};
} // namespace detail
/* \brief Segmented unique function. Keys are pointers to segments with key_segments_last -
* key_segments_first = n_segments + 1.
*
* \pre Input segment and output segment must not overlap.
*
* \param key_segments_first Beginning iterator of segments.
* \param key_segments_last End iterator of segments.
* \param val_first Beginning iterator of values.
* \param val_last End iterator of values.
* \param key_segments_out Output iterator of segments.
* \param val_out Output iterator of values.
*
* \return Number of unique values in total.
*/
template <typename KeyInIt, typename KeyOutIt, typename ValInIt,
typename ValOutIt, typename Comp>
size_t
SegmentedUnique(KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first,
ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out,
Comp comp) {
using Key = thrust::pair<size_t, typename thrust::iterator_traits<ValInIt>::value_type>;
dh::XGBCachingDeviceAllocator<char> alloc;
auto unique_key_it = dh::MakeTransformIterator<Key>(
thrust::make_counting_iterator(static_cast<size_t>(0)),
[=] __device__(size_t i) {
size_t seg = dh::SegmentId(key_segments_first, key_segments_last, i);
return thrust::make_pair(seg, *(val_first + i));
});
size_t segments_len = key_segments_last - key_segments_first;
thrust::fill(thrust::device, key_segments_out, key_segments_out + segments_len, 0);
size_t n_inputs = std::distance(val_first, val_last);
// Reduce the number of unique elements per segment, avoiding an intermediate
// array for `reduce_by_key`. It's limited by the types that atomicAdd supports. For
// example, size_t is not supported as of CUDA 10.2.
auto reduce_it = thrust::make_transform_output_iterator(
thrust::make_discard_iterator(),
detail::SegmentedUniqueReduceOp<Key, KeyOutIt>{key_segments_out});
auto uniques_ret = thrust::unique_by_key_copy(
thrust::cuda::par(alloc), unique_key_it, unique_key_it + n_inputs,
val_first, reduce_it, val_out,
[=] __device__(Key const &l, Key const &r) {
if (l.first == r.first) {
// In the same segment.
return comp(l.second, r.second);
}
return false;
});
auto n_uniques = uniques_ret.second - val_out;
CHECK_LE(n_uniques, n_inputs);
thrust::exclusive_scan(thrust::cuda::par(alloc), key_segments_out,
key_segments_out + segments_len, key_segments_out, 0);
return n_uniques;
}
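
A small CPU illustration of the segmented-unique semantics implemented above: duplicates are dropped only within a segment, and the output segment pointer is rebuilt from per-segment survivor counts followed by an exclusive scan (the role played by the atomicAdd reduce op and the final exclusive_scan). Standard C++ only; the concrete segments and values are made up for the example.

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  // Two segments: [1, 1, 2] and [2, 3, 3, 3]. The 2 repeats across the segment
  // boundary and must be kept, because uniqueness is per segment.
  std::vector<std::size_t> seg_ptr{0, 3, 7};
  std::vector<int> values{1, 1, 2, 2, 3, 3, 3};

  std::vector<int> out_values;
  std::vector<std::size_t> out_ptr(seg_ptr.size(), 0);
  for (std::size_t s = 0; s + 1 < seg_ptr.size(); ++s) {
    std::size_t kept = 0;
    for (std::size_t i = seg_ptr[s]; i < seg_ptr[s + 1]; ++i) {
      // Keep the first element of the segment and every value change inside it.
      if (i == seg_ptr[s] || values[i] != values[i - 1]) {
        out_values.push_back(values[i]);
        ++kept;
      }
    }
    out_ptr[s + 1] = kept;  // per-segment count, turned into offsets below
  }
  for (std::size_t s = 1; s < out_ptr.size(); ++s) {  // exclusive scan of counts
    out_ptr[s] += out_ptr[s - 1];
  }
  assert((out_values == std::vector<int>{1, 2, 2, 3}));
  assert((out_ptr == std::vector<std::size_t>{0, 2, 4}));
  return 0;
}
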
} // namespace dh } // namespace dh


@ -158,7 +158,6 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
uint32_t beg_col, uint32_t end_col, uint32_t beg_col, uint32_t end_col,
uint32_t thread_id) { uint32_t thread_id) {
CHECK_GE(end_col, beg_col); CHECK_GE(end_col, beg_col);
constexpr float kFactor = 8;
// Data groups, used in ranking. // Data groups, used in ranking.
std::vector<bst_uint> const& group_ptr = info.group_ptr_; std::vector<bst_uint> const& group_ptr = info.group_ptr_;
@ -175,11 +174,12 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
max_num_bins); max_num_bins);
if (n_bins == 0) { if (n_bins == 0) {
// cut_ptrs_ is initialized with a zero, so there's always an element at the back // cut_ptrs_ is initialized with a zero, so there's always an element at the back
CHECK_GE(local_ptrs.size(), 1);
local_ptrs.emplace_back(local_ptrs.back()); local_ptrs.emplace_back(local_ptrs.back());
continue; continue;
} }
sketch.Init(info.num_row_, 1.0 / (n_bins * kFactor)); sketch.Init(info.num_row_, 1.0 / (n_bins * WQSketch::kFactor));
for (auto const& entry : column) { for (auto const& entry : column) {
uint32_t weight_ind = 0; uint32_t weight_ind = 0;
if (use_group_ind) { if (use_group_ind) {
@ -329,7 +329,6 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {
const MetaInfo& info = p_fmat->Info(); const MetaInfo& info = p_fmat->Info();
// safe factor for better accuracy // safe factor for better accuracy
constexpr int kFactor = 8;
std::vector<WQSketch> sketchs; std::vector<WQSketch> sketchs;
const int nthread = omp_get_max_threads(); const int nthread = omp_get_max_threads();
@ -339,7 +338,7 @@ void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) {
unsigned const ncol = static_cast<unsigned>(info.num_col_); unsigned const ncol = static_cast<unsigned>(info.num_col_);
sketchs.resize(info.num_col_); sketchs.resize(info.num_col_);
for (auto& s : sketchs) { for (auto& s : sketchs) {
s.Init(info.num_row_, 1.0 / (max_num_bins * kFactor)); s.Init(info.num_row_, 1.0 / (max_num_bins * WQSketch::kFactor));
} }
// Data groups, used in ranking. // Data groups, used in ranking.
@ -410,9 +409,8 @@ void DenseCuts::Init
// This allows efficient training on wide data // This allows efficient training on wide data
size_t global_max_rows = max_rows; size_t global_max_rows = max_rows;
rabit::Allreduce<rabit::op::Sum>(&global_max_rows, 1); rabit::Allreduce<rabit::op::Sum>(&global_max_rows, 1);
constexpr int kFactor = 8;
size_t intermediate_num_cuts = size_t intermediate_num_cuts =
std::min(global_max_rows, static_cast<size_t>(max_num_bins * kFactor)); std::min(global_max_rows, static_cast<size_t>(max_num_bins * WQSketch::kFactor));
// gather the histogram data // gather the histogram data
rabit::SerializeReducer<WQSketch::SummaryContainer> sreducer; rabit::SerializeReducer<WQSketch::SummaryContainer> sreducer;
std::vector<WQSketch::SummaryContainer> summary_array; std::vector<WQSketch::SummaryContainer> summary_array;


@ -8,6 +8,7 @@
#include <thrust/functional.h> #include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h> #include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h> #include <thrust/iterator/transform_iterator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/reduce.h> #include <thrust/reduce.h>
#include <thrust/sort.h> #include <thrust/sort.h>
#include <thrust/binary_search.h> #include <thrust/binary_search.h>
@ -31,21 +32,20 @@ namespace common {
constexpr float SketchContainer::kFactor; constexpr float SketchContainer::kFactor;
namespace detail {
// Count the entries in each column and exclusive scan // Count the entries in each column and exclusive scan
void ExtractCuts(int device, void ExtractCutsSparse(int device, common::Span<SketchContainer::OffsetT const> cuts_ptr,
size_t num_cuts_per_feature,
Span<Entry const> sorted_data, Span<Entry const> sorted_data,
Span<size_t const> column_sizes_scan, Span<size_t const> column_sizes_scan,
Span<SketchEntry> out_cuts) { Span<SketchEntry> out_cuts) {
dh::LaunchN(device, out_cuts.size(), [=] __device__(size_t idx) { dh::LaunchN(device, out_cuts.size(), [=] __device__(size_t idx) {
// Each thread is responsible for obtaining one cut from the sorted input // Each thread is responsible for obtaining one cut from the sorted input
size_t column_idx = idx / num_cuts_per_feature; size_t column_idx = dh::SegmentId(cuts_ptr, idx);
size_t column_size = size_t column_size =
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx]; column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
size_t num_available_cuts = size_t num_available_cuts = cuts_ptr[column_idx + 1] - cuts_ptr[column_idx];
min(static_cast<size_t>(num_cuts_per_feature), column_size); size_t cut_idx = idx - cuts_ptr[column_idx];
size_t cut_idx = idx % num_cuts_per_feature;
if (cut_idx >= num_available_cuts) return;
Span<Entry const> column_entries = Span<Entry const> column_entries =
sorted_data.subspan(column_sizes_scan[column_idx], column_size); sorted_data.subspan(column_sizes_scan[column_idx], column_size);
size_t rank = (column_entries.size() * cut_idx) / size_t rank = (column_entries.size() * cut_idx) /
@ -55,31 +55,20 @@ void ExtractCuts(int device,
}); });
} }
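
A CPU sketch of the sampling rule inside ExtractCutsSparse: each available cut slot of a column picks an evenly spaced rank into that column's sorted entries. The rank expression is cut off by the hunk boundary above, so the divisor used here (the number of available cuts) is an assumption, and the single hard-coded column exists only for illustration.

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  // One sorted column of 10 feature values; suppose it gets 4 cut candidates.
  std::vector<float> column{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f};
  std::size_t num_available_cuts = 4;
  for (std::size_t cut_idx = 0; cut_idx < num_available_cuts; ++cut_idx) {
    // Evenly spaced rank into the sorted column (assumed divisor, see above).
    std::size_t rank = column.size() * cut_idx / num_available_cuts;
    std::printf("cut %zu -> rank %zu -> value %.1f\n", cut_idx, rank, column[rank]);
  }
  return 0;
}
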
/** void ExtractWeightedCutsSparse(int device,
* \brief Extracts the cuts from sorted data, considering weights. common::Span<SketchContainer::OffsetT const> cuts_ptr,
*
* \param device The device.
* \param cuts Output cuts.
* \param num_cuts_per_feature Number of cuts per feature.
* \param sorted_data Sorted entries in segments of columns.
* \param weights_scan Inclusive scan of weights for each entry in sorted_data.
* \param column_sizes_scan Describes the boundaries of column segments in sorted data.
*/
void ExtractWeightedCuts(int device,
size_t num_cuts_per_feature,
Span<Entry> sorted_data, Span<Entry> sorted_data,
Span<float> weights_scan, Span<float> weights_scan,
Span<size_t> column_sizes_scan, Span<size_t> column_sizes_scan,
Span<SketchEntry> cuts) { Span<SketchEntry> cuts) {
dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) { dh::LaunchN(device, cuts.size(), [=] __device__(size_t idx) {
// Each thread is responsible for obtaining one cut from the sorted input // Each thread is responsible for obtaining one cut from the sorted input
size_t column_idx = idx / num_cuts_per_feature; size_t column_idx = dh::SegmentId(cuts_ptr, idx);
size_t column_size = size_t column_size =
column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx]; column_sizes_scan[column_idx + 1] - column_sizes_scan[column_idx];
size_t num_available_cuts = size_t num_available_cuts = cuts_ptr[column_idx + 1] - cuts_ptr[column_idx];
min(static_cast<size_t>(num_cuts_per_feature), column_size); size_t cut_idx = idx - cuts_ptr[column_idx];
size_t cut_idx = idx % num_cuts_per_feature;
if (cut_idx >= num_available_cuts) return;
Span<Entry> column_entries = Span<Entry> column_entries =
sorted_data.subspan(column_sizes_scan[column_idx], column_size); sorted_data.subspan(column_sizes_scan[column_idx], column_size);
@ -109,7 +98,7 @@ void ExtractWeightedCuts(int device,
max(static_cast<size_t>(0), max(static_cast<size_t>(0),
min(sample_idx, column_entries.size() - 1)); min(sample_idx, column_entries.size() - 1));
} }
// repeated values will be filtered out on the CPU // repeated values will be filtered out later.
bst_float rmin = sample_idx > 0 ? column_weights_scan[sample_idx - 1] : 0.0f; bst_float rmin = sample_idx > 0 ? column_weights_scan[sample_idx - 1] : 0.0f;
bst_float rmax = column_weights_scan[sample_idx]; bst_float rmax = column_weights_scan[sample_idx];
cuts[idx] = WQSketch::Entry(rmin, rmax, rmax - rmin, cuts[idx] = WQSketch::Entry(rmin, rmax, rmax - rmin,
@ -117,31 +106,71 @@ void ExtractWeightedCuts(int device,
}); });
} }
void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end, size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows) {
SketchContainer* sketch_container, int num_cuts, double eps = 1.0 / (WQSketch::kFactor * max_bins);
size_t num_columns) { size_t dummy_nlevel;
dh::XGBCachingDeviceAllocator<char> alloc; size_t num_cuts;
const auto& host_data = page.data.ConstHostVector(); WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin, num_rows, eps, &dummy_nlevel, &num_cuts);
host_data.begin() + end); return std::min(num_cuts, num_rows);
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(), }
sorted_entries.end(), EntryCompareOp());
dh::caching_device_vector<size_t> column_sizes_scan; size_t RequiredSampleCuts(bst_row_t num_rows, bst_feature_t num_columns,
GetColumnSizesScan(device, &column_sizes_scan, size_t max_bins, size_t nnz) {
{sorted_entries.data().get(), sorted_entries.size()}, auto per_column = RequiredSampleCutsPerColumn(max_bins, num_rows);
num_columns); auto if_dense = num_columns * per_column;
thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan); auto result = std::min(nnz, if_dense);
return result;
}
dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts); size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz,
ExtractCuts(device, num_cuts, size_t num_bins, bool with_weights) {
dh::ToSpan(sorted_entries), size_t peak = 0;
dh::ToSpan(column_sizes_scan), // 0. Allocate cut pointer in quantile container by increasing: n_columns + 1
dh::ToSpan(cuts)); size_t total = (num_columns + 1) * sizeof(SketchContainer::OffsetT);
// 1. Copy and sort: 2 * bytes_per_element * shape
total += BytesPerElement(with_weights) * num_rows * num_columns;
peak = std::max(peak, total);
// 2. Deallocate bytes_per_element * shape due to reusing memory in sort.
total -= BytesPerElement(with_weights) * num_rows * num_columns / 2;
// 3. Allocate column size scan by increasing: n_columns + 1
total += (num_columns + 1) * sizeof(SketchContainer::OffsetT);
// 4. Allocate cut pointer by increasing: n_columns + 1
total += (num_columns + 1) * sizeof(SketchContainer::OffsetT);
// 5. Allocate cuts: assuming rows is greater than bins: n_columns * limit_size
total += RequiredSampleCuts(num_rows, num_bins, num_bins, nnz) * sizeof(SketchEntry);
// 6. Deallocate copied entries by reducing: bytes_per_element * shape.
peak = std::max(peak, total);
total -= (BytesPerElement(with_weights) * num_rows * num_columns) / 2;
// 7. Deallocate column size scan.
peak = std::max(peak, total);
total -= (num_columns + 1) * sizeof(SketchContainer::OffsetT);
// 8. Deallocate cut size scan.
total -= (num_columns + 1) * sizeof(SketchContainer::OffsetT);
// 9. Allocate final cut values, min values, cut ptrs: std::min(rows, bins + 1) *
// n_columns + n_columns + n_columns + 1
total += std::min(num_rows, num_bins) * num_columns * sizeof(float);
total += num_columns *
sizeof(std::remove_reference_t<decltype(
std::declval<HistogramCuts>().MinValues())>::value_type);
total += (num_columns + 1) *
sizeof(std::remove_reference_t<decltype(
std::declval<HistogramCuts>().Ptrs())>::value_type);
peak = std::max(peak, total);
// add cuts into sketches return peak;
thrust::host_vector<SketchEntry> host_cuts(cuts); }
sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
bst_row_t num_rows, size_t columns, size_t nnz, int device,
size_t num_cuts, bool has_weight) {
if (sketch_batch_num_elements == 0) {
auto required_memory = RequiredMemory(num_rows, columns, nnz, num_cuts, has_weight);
// use up to 80% of available space
sketch_batch_num_elements = (dh::AvailableMemory(device) -
required_memory * 0.8);
}
return sketch_batch_num_elements;
} }
void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc, void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
@ -150,7 +179,7 @@ void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
// Sort both entries and wegihts. // Sort both entries and wegihts.
thrust::sort_by_key(thrust::cuda::par(*alloc), sorted_entries->begin(), thrust::sort_by_key(thrust::cuda::par(*alloc), sorted_entries->begin(),
sorted_entries->end(), weights->begin(), sorted_entries->end(), weights->begin(),
EntryCompareOp()); detail::EntryCompareOp());
// Scan weights // Scan weights
thrust::inclusive_scan_by_key(thrust::cuda::par(*alloc), thrust::inclusive_scan_by_key(thrust::cuda::par(*alloc),
@ -160,6 +189,46 @@ void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
return a.index == b.index; return a.index == b.index;
}); });
} }
} // namespace detail
void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end,
SketchContainer *sketch_container, int num_cuts_per_feature,
size_t num_columns) {
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::caching_device_vector<Entry> sorted_entries(host_data.begin() + begin,
host_data.begin() + end);
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), detail::EntryCompareOp());
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
dh::caching_device_vector<size_t> column_sizes_scan;
data::IsValidFunctor dummy_is_valid(std::numeric_limits<float>::quiet_NaN());
auto batch_it = dh::MakeTransformIterator<data::COOTuple>(
sorted_entries.data().get(),
[] __device__(Entry const &e) -> data::COOTuple {
return {0, e.index, e.fvalue}; // row_idx is not needed for scanning column size.
});
detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature,
batch_it, dummy_is_valid,
0, sorted_entries.size(),
&cuts_ptr, &column_sizes_scan);
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size());
detail::ExtractCutsSparse(device, d_cuts_ptr, dh::ToSpan(sorted_entries),
dh::ToSpan(column_sizes_scan), dh::ToSpan(cuts));
// add cuts into sketches
sorted_entries.clear();
sorted_entries.shrink_to_fit();
CHECK_EQ(sorted_entries.capacity(), 0);
CHECK_NE(cuts_ptr.Size(), 0);
sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
}
void ProcessWeightedBatch(int device, const SparsePage& page, void ProcessWeightedBatch(int device, const SparsePage& page,
Span<const float> weights, size_t begin, size_t end, Span<const float> weights, size_t begin, size_t end,
@ -204,40 +273,53 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
d_temp_weights[idx] = weights[ridx + base_rowid]; d_temp_weights[idx] = weights[ridx + base_rowid];
}); });
} }
SortByWeight(&alloc, &temp_weights, &sorted_entries); detail::SortByWeight(&alloc, &temp_weights, &sorted_entries);
HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
dh::caching_device_vector<size_t> column_sizes_scan; dh::caching_device_vector<size_t> column_sizes_scan;
GetColumnSizesScan(device, &column_sizes_scan, data::IsValidFunctor dummy_is_valid(std::numeric_limits<float>::quiet_NaN());
{sorted_entries.data().get(), sorted_entries.size()}, auto batch_it = dh::MakeTransformIterator<data::COOTuple>(
num_columns); sorted_entries.data().get(),
thrust::host_vector<size_t> host_column_sizes_scan(column_sizes_scan); [] __device__(Entry const &e) -> data::COOTuple {
return {0, e.index, e.fvalue}; // row_idx is not needed for scanning column size. });
});
detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature,
batch_it, dummy_is_valid,
0, sorted_entries.size(),
&cuts_ptr, &column_sizes_scan);
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
// Extract cuts // Extract cuts
dh::caching_device_vector<SketchEntry> cuts(num_columns * num_cuts_per_feature); detail::ExtractWeightedCutsSparse(device, d_cuts_ptr,
ExtractWeightedCuts(device, num_cuts_per_feature,
dh::ToSpan(sorted_entries), dh::ToSpan(sorted_entries),
dh::ToSpan(temp_weights), dh::ToSpan(temp_weights),
dh::ToSpan(column_sizes_scan), dh::ToSpan(column_sizes_scan),
dh::ToSpan(cuts)); dh::ToSpan(cuts));
// add cuts into sketches // add cuts into sketches
thrust::host_vector<SketchEntry> host_cuts(cuts); sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
sketch_container->Push(num_cuts_per_feature, host_cuts, host_column_sizes_scan);
} }
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
size_t sketch_batch_num_elements) { size_t sketch_batch_num_elements) {
// Configure batch size based on available memory // Configure batch size based on available memory
bool has_weights = dmat->Info().weights_.Size() > 0; bool has_weights = dmat->Info().weights_.Size() > 0;
size_t num_cuts_per_feature = RequiredSampleCuts(max_bins, dmat->Info().num_row_); size_t num_cuts_per_feature =
sketch_batch_num_elements = SketchBatchNumElements( detail::RequiredSampleCutsPerColumn(max_bins, dmat->Info().num_row_);
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements, sketch_batch_num_elements,
dmat->Info().num_col_, device, num_cuts_per_feature, has_weights); dmat->Info().num_row_,
dmat->Info().num_col_,
dmat->Info().num_nonzero_,
device, num_cuts_per_feature, has_weights);
HistogramCuts cuts; HistogramCuts cuts;
DenseCuts dense_cuts(&cuts); DenseCuts dense_cuts(&cuts);
SketchContainer sketch_container(max_bins, dmat->Info().num_col_, SketchContainer sketch_container(max_bins, dmat->Info().num_col_,
dmat->Info().num_row_); dmat->Info().num_row_, device);
dmat->Info().weights_.SetDevice(device); dmat->Info().weights_.SetDevice(device);
for (const auto& batch : dmat->GetBatches<SparsePage>()) { for (const auto& batch : dmat->GetBatches<SparsePage>()) {
@ -261,8 +343,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
} }
} }
} }
sketch_container.MakeCuts(&cuts);
dense_cuts.Init(&sketch_container.sketches_, max_bins, dmat->Info().num_row_);
return cuts; return cuts;
} }
} // namespace common } // namespace common


@ -1,5 +1,8 @@
/*! /*!
* Copyright 2020 XGBoost contributors * Copyright 2020 XGBoost contributors
*
* \brief Front end and utilities for GPU-based sketching. Works on a sliding window
* instead of a stream.
*/ */
#ifndef COMMON_HIST_UTIL_CUH_ #ifndef COMMON_HIST_UTIL_CUH_
#define COMMON_HIST_UTIL_CUH_ #define COMMON_HIST_UTIL_CUH_
@ -7,74 +10,15 @@
#include <thrust/host_vector.h> #include <thrust/host_vector.h>
#include "hist_util.h" #include "hist_util.h"
#include "threading_utils.h" #include "quantile.cuh"
#include "device_helpers.cuh" #include "device_helpers.cuh"
#include "timer.h"
#include "../data/device_adapter.cuh" #include "../data/device_adapter.cuh"
namespace xgboost { namespace xgboost {
namespace common { namespace common {
using WQSketch = DenseCuts::WQSketch; namespace detail {
using SketchEntry = WQSketch::Entry;
/*!
* \brief A container that holds the device sketches across all
* sparse page batches which are distributed to different devices.
* As sketches are aggregated by column, the mutex guards
* multiple devices pushing sketch summary for the same column
* across distinct rows.
*/
struct SketchContainer {
std::vector<DenseCuts::WQSketch> sketches_; // NOLINT
static constexpr int kOmpNumColsParallelizeLimit = 1000;
static constexpr float kFactor = 8;
SketchContainer(int max_bin, size_t num_columns, size_t num_rows) {
// Initialize Sketches for this dmatrix
sketches_.resize(num_columns);
#pragma omp parallel for schedule(static) if (num_columns > kOmpNumColsParallelizeLimit) // NOLINT
for (int icol = 0; icol < num_columns; ++icol) { // NOLINT
sketches_[icol].Init(num_rows, 1.0 / (8 * max_bin));
}
}
/**
* \brief Pushes cuts to the sketches.
*
* \param entries_per_column The entries per column.
* \param entries Vector of cuts from all columns, length
* entries_per_column * num_columns. \param column_scan Exclusive scan
* of column sizes. Used to detect cases where there are fewer entries than we
* have storage for.
*/
void Push(size_t entries_per_column,
const thrust::host_vector<SketchEntry>& entries,
const thrust::host_vector<size_t>& column_scan) {
#pragma omp parallel for schedule(static) if (sketches_.size() > SketchContainer::kOmpNumColsParallelizeLimit) // NOLINT
for (int icol = 0; icol < sketches_.size(); ++icol) {
size_t column_size = column_scan[icol + 1] - column_scan[icol];
if (column_size == 0) continue;
WQuantileSketch<bst_float, bst_float>::SummaryContainer summary;
size_t num_available_cuts =
std::min(size_t(entries_per_column), column_size);
summary.Reserve(num_available_cuts);
summary.MakeFromSorted(&entries[entries_per_column * icol],
num_available_cuts);
sketches_[icol].PushSummary(summary);
}
}
// Prevent copying/assigning/moving this as its internals can't be
// assigned/copied/moved
SketchContainer(const SketchContainer&) = delete;
SketchContainer(SketchContainer&& that) {
std::swap(sketches_, that.sketches_);
}
SketchContainer& operator=(const SketchContainer&) = delete;
SketchContainer& operator=(SketchContainer&&) = delete;
};
struct EntryCompareOp { struct EntryCompareOp {
__device__ bool operator()(const Entry& a, const Entry& b) { __device__ bool operator()(const Entry& a, const Entry& b) {
if (a.index == b.index) { if (a.index == b.index) {
@ -88,100 +32,105 @@ struct EntryCompareOp {
* \brief Extracts the cuts from sorted data. * \brief Extracts the cuts from sorted data.
* *
* \param device The device. * \param device The device.
* \param cuts Output cuts * \param cuts_ptr Column pointers to CSC structured cuts
* \param num_cuts_per_feature Number of cuts per feature.
* \param sorted_data Sorted entries in segments of columns * \param sorted_data Sorted entries in segments of columns
* \param column_sizes_scan Describes the boundaries of column segments in * \param column_sizes_scan Describes the boundaries of column segments in sorted data
* sorted data * \param out_cuts Output cut values
*/ */
void ExtractCuts(int device, void ExtractCutsSparse(int device, common::Span<SketchContainer::OffsetT const> cuts_ptr,
size_t num_cuts_per_feature,
Span<Entry const> sorted_data, Span<Entry const> sorted_data,
Span<size_t const> column_sizes_scan, Span<size_t const> column_sizes_scan,
Span<SketchEntry> out_cuts); Span<SketchEntry> out_cuts);
// Count the entries in each column and exclusive scan /**
inline void GetColumnSizesScan(int device, * \brief Extracts the cuts from sorted data, considering weights.
dh::caching_device_vector<size_t>* column_sizes_scan, *
Span<const Entry> entries, size_t num_columns) { * \param device The device.
column_sizes_scan->resize(num_columns + 1, 0); * \param cuts_ptr Column pointers to CSC structured cuts
auto d_column_sizes_scan = column_sizes_scan->data().get(); * \param sorted_data Sorted entries in segments of columns.
auto d_entries = entries.data(); * \param weights_scan Inclusive scan of weights for each entry in sorted_data.
dh::LaunchN(device, entries.size(), [=] __device__(size_t idx) { * \param column_sizes_scan Describes the boundaries of column segments in sorted data.
auto& e = d_entries[idx]; * \param cuts Output cuts.
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT */
&d_column_sizes_scan[e.index]), void ExtractWeightedCutsSparse(int device,
static_cast<unsigned long long>(1)); // NOLINT common::Span<SketchContainer::OffsetT const> cuts_ptr,
}); Span<Entry> sorted_data,
dh::XGBCachingDeviceAllocator<char> alloc; Span<float> weights_scan,
thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(), Span<size_t> column_sizes_scan,
column_sizes_scan->end(), column_sizes_scan->begin()); Span<SketchEntry> cuts);
}
// For adapter. // Get per-column sizes from the adapter batch and the cut pointer for output cuts.
template <typename Iter> template <typename Iter>
void GetColumnSizesScan(int device, size_t num_columns, void GetColumnSizesScan(int device, size_t num_columns, size_t num_cuts_per_feature,
Iter batch_iter, data::IsValidFunctor is_valid, Iter batch_iter, data::IsValidFunctor is_valid,
size_t begin, size_t end, size_t begin, size_t end,
HostDeviceVector<SketchContainer::OffsetT> *cuts_ptr,
dh::caching_device_vector<size_t>* column_sizes_scan) { dh::caching_device_vector<size_t>* column_sizes_scan) {
dh::XGBCachingDeviceAllocator<char> alloc;
column_sizes_scan->resize(num_columns + 1, 0); column_sizes_scan->resize(num_columns + 1, 0);
cuts_ptr->SetDevice(device);
cuts_ptr->Resize(num_columns + 1, 0);
dh::XGBCachingDeviceAllocator<char> alloc;
auto d_column_sizes_scan = column_sizes_scan->data().get(); auto d_column_sizes_scan = column_sizes_scan->data().get();
dh::LaunchN(device, end - begin, [=] __device__(size_t idx) { dh::LaunchN(device, end - begin, [=] __device__(size_t idx) {
auto e = batch_iter[begin + idx]; auto e = batch_iter[begin + idx];
if (is_valid(e)) { if (is_valid(e)) {
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT atomicAdd(&d_column_sizes_scan[e.column_idx], static_cast<size_t>(1));
&d_column_sizes_scan[e.column_idx]),
static_cast<unsigned long long>(1)); // NOLINT
} }
}); });
// Calculate cuts CSC pointer
auto cut_ptr_it = dh::MakeTransformIterator<size_t>(
column_sizes_scan->begin(), [=] __device__(size_t column_size) {
return thrust::min(num_cuts_per_feature, column_size);
});
thrust::exclusive_scan(thrust::cuda::par(alloc), cut_ptr_it,
cut_ptr_it + column_sizes_scan->size(),
cuts_ptr->DevicePointer());
thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(), thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(),
column_sizes_scan->end(), column_sizes_scan->begin()); column_sizes_scan->end(), column_sizes_scan->begin());
} }
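
A CPU sketch of the two scans GetColumnSizesScan produces above: column_sizes_scan is the exclusive scan of the raw per-column valid-entry counts, while cuts_ptr caps each column at num_cuts_per_feature before scanning. The counts and the budget below are made-up example values.

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  // Per-column valid-entry counts (what the atomicAdd loop above accumulates).
  std::vector<std::size_t> column_sizes{5, 0, 12};
  std::size_t num_cuts_per_feature = 8;

  std::vector<std::size_t> cuts_ptr(column_sizes.size() + 1, 0);
  std::vector<std::size_t> column_sizes_scan(column_sizes.size() + 1, 0);
  for (std::size_t i = 0; i < column_sizes.size(); ++i) {
    // Cut pointer: each column gets at most the per-feature budget.
    cuts_ptr[i + 1] = cuts_ptr[i] + std::min(num_cuts_per_feature, column_sizes[i]);
    // Column offsets: plain exclusive scan of the counts (CSC-style pointer).
    column_sizes_scan[i + 1] = column_sizes_scan[i] + column_sizes[i];
  }
  assert((cuts_ptr == std::vector<std::size_t>{0, 5, 5, 13}));
  assert((column_sizes_scan == std::vector<std::size_t>{0, 5, 5, 17}));
  return 0;
}
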
inline size_t BytesPerElement(bool has_weight) { inline size_t constexpr BytesPerElement(bool has_weight) {
// Double the memory usage for sorting. We need to assign weight for each element, so // Double the memory usage for sorting. We need to assign weight for each element, so
// sizeof(float) is added to all elements. // sizeof(float) is added to all elements.
return (has_weight ? sizeof(Entry) + sizeof(float) : sizeof(Entry)) * 2; return (has_weight ? sizeof(Entry) + sizeof(float) : sizeof(Entry)) * 2;
} }
inline size_t SketchBatchNumElements(size_t sketch_batch_num_elements, /* \brief Calculate the length of sliding window. Returns `sketch_batch_num_elements`
size_t columns, int device, * directly if it's not 0.
size_t num_cuts, bool has_weight) { */
if (sketch_batch_num_elements == 0) { size_t SketchBatchNumElements(size_t sketch_batch_num_elements,
size_t bytes_per_element = BytesPerElement(has_weight); bst_row_t num_rows, size_t columns, size_t nnz, int device,
size_t bytes_cuts = num_cuts * columns * sizeof(SketchEntry); size_t num_cuts, bool has_weight);
size_t bytes_num_columns = (columns + 1) * sizeof(size_t);
// use up to 80% of available space
sketch_batch_num_elements = (dh::AvailableMemory(device) -
bytes_cuts - bytes_num_columns) *
0.8 / bytes_per_element;
}
return sketch_batch_num_elements;
}
// Compute number of sample cuts needed on local node to maintain accuracy // Compute number of sample cuts needed on local node to maintain accuracy
// We take more cuts than needed and then reduce them later // We take more cuts than needed and then reduce them later
inline size_t RequiredSampleCuts(int max_bins, size_t num_rows) { size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows);
double eps = 1.0 / (SketchContainer::kFactor * max_bins);
size_t dummy_nlevel;
size_t num_cuts;
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
num_rows, eps, &dummy_nlevel, &num_cuts);
return std::min(num_cuts, num_rows);
}
// sketch_batch_num_elements 0 means autodetect. Only modify this for testing.
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
size_t sketch_batch_num_elements = 0);
/* \brief Estimate required memory for each sliding window.
*
* The estimate is not exact: obtaining the precise memory usage for a sparse dataset
* would require walking through the whole dataset first. Also, if the data comes from a
* host DMatrix, the weight, group and offset are copied on the first batch, which is not
* accounted for by this function.
*
* \param num_rows Number of rows in this worker.
* \param num_columns Number of columns for this dataset.
* \param nnz Number of non-zero elements. Put in something greater than rows *
* cols if nnz is unknown.
* \param num_bins Number of histogram bins.
* \param with_weights Whether weight is used, works the same for ranking and other models.
*
* \return The estimated bytes
*/
size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz,
size_t num_bins, bool with_weights);
// Count the valid entries in each column and copy them out.
template <typename AdapterBatch, typename BatchIter> template <typename AdapterBatch, typename BatchIter>
void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
Range1d range, float missing, Range1d range, float missing,
size_t columns, int device, size_t columns, size_t cuts_per_feature, int device,
thrust::host_vector<size_t>* host_column_sizes_scan, HostDeviceVector<SketchContainer::OffsetT>* cut_sizes_scan,
dh::caching_device_vector<size_t>* column_sizes_scan, dh::caching_device_vector<size_t>* column_sizes_scan,
dh::caching_device_vector<Entry>* sorted_entries) { dh::caching_device_vector<Entry>* sorted_entries) {
auto entry_iter = dh::MakeTransformIterator<Entry>( auto entry_iter = dh::MakeTransformIterator<Entry>(
@ -191,16 +140,12 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
}); });
data::IsValidFunctor is_valid(missing); data::IsValidFunctor is_valid(missing);
// Work out how many valid entries we have in each column // Work out how many valid entries we have in each column
GetColumnSizesScan(device, columns, GetColumnSizesScan(device, columns, cuts_per_feature,
batch_iter, is_valid, batch_iter, is_valid,
range.begin(), range.end(), range.begin(), range.end(),
cut_sizes_scan,
column_sizes_scan); column_sizes_scan);
host_column_sizes_scan->resize(column_sizes_scan->size()); size_t num_valid = column_sizes_scan->back();
thrust::copy(column_sizes_scan->begin(), column_sizes_scan->end(),
host_column_sizes_scan->begin());
size_t num_valid = host_column_sizes_scan->back();
// Copy current subset of valid elements into temporary storage and sort // Copy current subset of valid elements into temporary storage and sort
sorted_entries->resize(num_valid); sorted_entries->resize(num_valid);
dh::XGBCachingDeviceAllocator<char> alloc; dh::XGBCachingDeviceAllocator<char> alloc;
@ -208,6 +153,16 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter,
entry_iter + range.end(), sorted_entries->begin(), is_valid); entry_iter + range.end(), sorted_entries->begin(), is_valid);
} }
void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
dh::caching_device_vector<float>* weights,
dh::caching_device_vector<Entry>* sorted_entries);
} // namespace detail
// Compute sketch on DMatrix.
// sketch_batch_num_elements 0 means autodetect. Only modify this for testing.
HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins,
size_t sketch_batch_num_elements = 0);
template <typename AdapterBatch> template <typename AdapterBatch>
void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns, void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns,
size_t begin, size_t end, float missing, size_t begin, size_t end, float missing,
@ -215,41 +170,33 @@ void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns,
// Copy current subset of valid elements into temporary storage and sort // Copy current subset of valid elements into temporary storage and sort
dh::caching_device_vector<Entry> sorted_entries; dh::caching_device_vector<Entry> sorted_entries;
dh::caching_device_vector<size_t> column_sizes_scan; dh::caching_device_vector<size_t> column_sizes_scan;
thrust::host_vector<size_t> host_column_sizes_scan;
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>( auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
thrust::make_counting_iterator(0llu), thrust::make_counting_iterator(0llu),
[=] __device__(size_t idx) { return batch.GetElement(idx); }); [=] __device__(size_t idx) { return batch.GetElement(idx); });
MakeEntriesFromAdapter(batch, batch_iter, {begin, end}, missing, columns, device, HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
&host_column_sizes_scan, detail::MakeEntriesFromAdapter(batch, batch_iter, {begin, end}, missing,
columns, num_cuts, device,
&cuts_ptr,
&column_sizes_scan, &column_sizes_scan,
&sorted_entries); &sorted_entries);
dh::XGBCachingDeviceAllocator<char> alloc; dh::XGBCachingDeviceAllocator<char> alloc;
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(), thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), EntryCompareOp()); sorted_entries.end(), detail::EntryCompareOp());
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
// Extract the cuts from all columns concurrently // Extract the cuts from all columns concurrently
dh::caching_device_vector<SketchEntry> cuts(columns * num_cuts); detail::ExtractCutsSparse(device, d_cuts_ptr,
ExtractCuts(device, num_cuts,
dh::ToSpan(sorted_entries), dh::ToSpan(sorted_entries),
dh::ToSpan(column_sizes_scan), dh::ToSpan(column_sizes_scan),
dh::ToSpan(cuts)); dh::ToSpan(cuts));
sorted_entries.clear();
sorted_entries.shrink_to_fit();
// Push cuts into sketches stored in host memory sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
thrust::host_vector<SketchEntry> host_cuts(cuts);
sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan);
} }
void ExtractWeightedCuts(int device,
size_t num_cuts_per_feature,
Span<Entry> sorted_data,
Span<float> weights_scan,
Span<size_t> column_sizes_scan,
Span<SketchEntry> cuts);
void SortByWeight(dh::XGBCachingDeviceAllocator<char>* alloc,
dh::caching_device_vector<float>* weights,
dh::caching_device_vector<Entry>* sorted_entries);
template <typename Batch> template <typename Batch>
void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
int num_cuts_per_feature, int num_cuts_per_feature,
@ -268,10 +215,11 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
[=] __device__(size_t idx) { return batch.GetElement(idx); }); [=] __device__(size_t idx) { return batch.GetElement(idx); });
dh::caching_device_vector<Entry> sorted_entries; dh::caching_device_vector<Entry> sorted_entries;
dh::caching_device_vector<size_t> column_sizes_scan; dh::caching_device_vector<size_t> column_sizes_scan;
thrust::host_vector<size_t> host_column_sizes_scan; HostDeviceVector<SketchContainer::OffsetT> cuts_ptr;
MakeEntriesFromAdapter(batch, batch_iter, detail::MakeEntriesFromAdapter(batch, batch_iter,
{begin, end}, missing, columns, device, {begin, end}, missing,
&host_column_sizes_scan, columns, num_cuts_per_feature, device,
&cuts_ptr,
&column_sizes_scan, &column_sizes_scan,
&sorted_entries); &sorted_entries);
data::IsValidFunctor is_valid(missing); data::IsValidFunctor is_valid(missing);
@ -297,6 +245,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
is_valid); is_valid);
CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size()); CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size());
} else { } else {
CHECK_EQ(batch.NumRows(), weights.size());
auto const weight_iter = dh::MakeTransformIterator<float>( auto const weight_iter = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0lu), thrust::make_counting_iterator(0lu),
[=]__device__(size_t idx) -> float { [=]__device__(size_t idx) -> float {
@ -310,90 +259,114 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info,
CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size()); CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size());
} }
SortByWeight(&alloc, &temp_weights, &sorted_entries); detail::SortByWeight(&alloc, &temp_weights, &sorted_entries);
auto const& h_cuts_ptr = cuts_ptr.ConstHostVector();
auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan();
// Extract cuts // Extract cuts
dh::caching_device_vector<SketchEntry> cuts(columns * num_cuts_per_feature); dh::caching_device_vector<SketchEntry> cuts(h_cuts_ptr.back());
ExtractWeightedCuts(device, num_cuts_per_feature, detail::ExtractWeightedCutsSparse(device, d_cuts_ptr,
dh::ToSpan(sorted_entries), dh::ToSpan(sorted_entries),
dh::ToSpan(temp_weights), dh::ToSpan(temp_weights),
dh::ToSpan(column_sizes_scan), dh::ToSpan(column_sizes_scan),
dh::ToSpan(cuts)); dh::ToSpan(cuts));
sorted_entries.clear();
sorted_entries.shrink_to_fit();
// add cuts into sketches // add cuts into sketches
thrust::host_vector<SketchEntry> host_cuts(cuts); sketch_container->Push(cuts_ptr.ConstDeviceSpan(), &cuts);
sketch_container->Push(num_cuts_per_feature, host_cuts, host_column_sizes_scan);
} }
template <typename AdapterT> template <typename AdapterT>
HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins, HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins,
float missing, float missing,
size_t sketch_batch_num_elements = 0) { size_t sketch_batch_num_elements = 0) {
size_t num_cuts = RequiredSampleCuts(num_bins, adapter->NumRows()); size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, adapter->NumRows());
CHECK(adapter->NumRows() != data::kAdapterUnknownSize); CHECK(adapter->NumRows() != data::kAdapterUnknownSize);
CHECK(adapter->NumColumns() != data::kAdapterUnknownSize); CHECK(adapter->NumColumns() != data::kAdapterUnknownSize);
adapter->BeforeFirst(); adapter->BeforeFirst();
adapter->Next(); adapter->Next();
auto& batch = adapter->Value(); auto& batch = adapter->Value();
sketch_batch_num_elements = SketchBatchNumElements( sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements, sketch_batch_num_elements,
adapter->NumColumns(), adapter->DeviceIdx(), num_cuts, false); adapter->NumRows(), adapter->NumColumns(), std::numeric_limits<size_t>::max(),
adapter->DeviceIdx(),
num_cuts_per_feature, false);
// Enforce single batch // Enforce single batch
CHECK(!adapter->Next()); CHECK(!adapter->Next());
HistogramCuts cuts; HistogramCuts cuts;
DenseCuts dense_cuts(&cuts);
SketchContainer sketch_container(num_bins, adapter->NumColumns(), SketchContainer sketch_container(num_bins, adapter->NumColumns(),
adapter->NumRows()); adapter->NumRows(), adapter->DeviceIdx());
for (auto begin = 0ull; begin < batch.Size(); for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
begin += sketch_batch_num_elements) {
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
auto const& batch = adapter->Value(); auto const& batch = adapter->Value();
ProcessSlidingWindow(batch, adapter->DeviceIdx(), adapter->NumColumns(), ProcessSlidingWindow(batch, adapter->DeviceIdx(), adapter->NumColumns(),
begin, end, missing, &sketch_container, num_cuts); begin, end, missing, &sketch_container, num_cuts_per_feature);
} }
dense_cuts.Init(&sketch_container.sketches_, num_bins, adapter->NumRows()); sketch_container.MakeCuts(&cuts);
return cuts; return cuts;
} }
/*
* \brief Perform sketching on GPU.
*
* \param batch A batch from adapter.
* \param num_bins Bins per column.
* \param missing Floating point value that represents invalid value.
* \param sketch_container Container for output sketch.
* \param sketch_batch_num_elements Number of elements per sliding window; use it only for
* testing.
*/
template <typename Batch> template <typename Batch>
void AdapterDeviceSketch(Batch batch, int num_bins, void AdapterDeviceSketch(Batch batch, int num_bins,
float missing, int device, float missing, SketchContainer* sketch_container,
SketchContainer* sketch_container,
size_t sketch_batch_num_elements = 0) { size_t sketch_batch_num_elements = 0) {
size_t num_rows = batch.NumRows(); size_t num_rows = batch.NumRows();
size_t num_cols = batch.NumCols(); size_t num_cols = batch.NumCols();
size_t num_cuts = RequiredSampleCuts(num_bins, num_rows); size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows);
sketch_batch_num_elements = SketchBatchNumElements( int32_t device = sketch_container->DeviceIdx();
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements, sketch_batch_num_elements,
num_cols, device, num_cuts, false); num_rows, num_cols, std::numeric_limits<size_t>::max(),
device, num_cuts_per_feature, false);
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
ProcessSlidingWindow(batch, device, num_cols, ProcessSlidingWindow(batch, device, num_cols,
begin, end, missing, sketch_container, num_cuts); begin, end, missing, sketch_container, num_cuts_per_feature);
} }
} }
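
A CPU sketch of the sliding-window driver used by AdapterDeviceSketch above: the adapter batch is consumed in windows of at most sketch_batch_num_elements elements rather than as a stream. ProcessWindow is a stand-in for ProcessSlidingWindow, and the sizes are arbitrary example values.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Stand-in for ProcessSlidingWindow: the real code copies, sorts and pushes the
// window into the sketch container on the GPU.
void ProcessWindow(std::vector<float> const& batch, std::size_t begin, std::size_t end) {
  std::printf("window [%zu, %zu) of %zu elements\n", begin, end, batch.size());
}

int main() {
  std::vector<float> batch(10);                // pretend adapter batch of 10 elements
  std::size_t sketch_batch_num_elements = 4;   // window size; auto-detected in the real code
  for (std::size_t begin = 0; begin < batch.size(); begin += sketch_batch_num_elements) {
    std::size_t end = std::min(batch.size(), begin + sketch_batch_num_elements);
    ProcessWindow(batch, begin, end);
  }
  return 0;
}
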
/*
* \brief Perform weighted sketching on GPU.
*
* When the weight in info is empty, this function is equivalent to the unweighted version.
*/
template <typename Batch> template <typename Batch>
void AdapterDeviceSketchWeighted(Batch batch, int num_bins, void AdapterDeviceSketchWeighted(Batch batch, int num_bins,
MetaInfo const& info, MetaInfo const& info,
float missing, float missing, SketchContainer* sketch_container,
int device,
SketchContainer* sketch_container,
size_t sketch_batch_num_elements = 0) { size_t sketch_batch_num_elements = 0) {
if (info.weights_.Size() == 0) {
return AdapterDeviceSketch(batch, num_bins, missing, sketch_container, sketch_batch_num_elements);
}
size_t num_rows = batch.NumRows(); size_t num_rows = batch.NumRows();
size_t num_cols = batch.NumCols(); size_t num_cols = batch.NumCols();
size_t num_cuts = RequiredSampleCuts(num_bins, num_rows); size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows);
sketch_batch_num_elements = SketchBatchNumElements( int32_t device = sketch_container->DeviceIdx();
sketch_batch_num_elements = detail::SketchBatchNumElements(
sketch_batch_num_elements, sketch_batch_num_elements,
num_cols, device, num_cuts, true); num_rows, num_cols, std::numeric_limits<size_t>::max(),
device, num_cuts_per_feature, true);
for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) {
size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements));
ProcessWeightedSlidingWindow(batch, info, ProcessWeightedSlidingWindow(batch, info,
num_cuts, num_cuts_per_feature,
CutsBuilder::UseGroup(info), missing, device, num_cols, begin, end, CutsBuilder::UseGroup(info), missing, device, num_cols, begin, end,
sketch_container); sketch_container);
} }


@ -167,7 +167,7 @@ class CutsBuilder {
/*! \brief Cut configuration for sparse dataset. */ /*! \brief Cut configuration for sparse dataset. */
class SparseCuts : public CutsBuilder { class SparseCuts : public CutsBuilder {
/* \brief Distrbute columns to each thread according to number of entries. */ /* \brief Distribute columns to each thread according to number of entries. */
static std::vector<size_t> LoadBalance(SparsePage const& page, size_t const nthreads); static std::vector<size_t> LoadBalance(SparsePage const& page, size_t const nthreads);
Monitor monitor_; Monitor monitor_;


@ -205,7 +205,7 @@ class HostDeviceVectorImpl {
// data is on the host // data is on the host
LazyResizeDevice(data_h_.size()); LazyResizeDevice(data_h_.size());
SetDevice(); SetDevice();
dh::safe_cuda(cudaMemcpy(data_d_->data().get(), dh::safe_cuda(cudaMemcpyAsync(data_d_->data().get(),
data_h_.data(), data_h_.data(),
data_d_->size() * sizeof(T), data_d_->size() * sizeof(T),
cudaMemcpyHostToDevice)); cudaMemcpyHostToDevice));

572
src/common/quantile.cu Normal file

@ -0,0 +1,572 @@
/*!
* Copyright 2020 by XGBoost Contributors
*/
#include <thrust/unique.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/binary_search.h>
#include <thrust/transform_scan.h>
#include <thrust/execution_policy.h>
#include <memory>
#include <utility>
#include "xgboost/span.h"
#include "quantile.h"
#include "quantile.cuh"
#include "hist_util.h"
#include "device_helpers.cuh"
#include "common.h"
namespace xgboost {
namespace common {
using WQSketch = DenseCuts::WQSketch;
using SketchEntry = WQSketch::Entry;
// Algorithm 4 in XGBoost's paper, using binary search to find i.
__device__ SketchEntry BinarySearchQuery(Span<SketchEntry const> const& entries, float rank) {
assert(entries.size() >= 2);
rank *= 2;
if (rank < entries.front().rmin + entries.front().rmax) {
return entries.front();
}
if (rank >= entries.back().rmin + entries.back().rmax) {
return entries.back();
}
auto begin = dh::MakeTransformIterator<float>(
entries.begin(), [=] __device__(SketchEntry const &entry) {
return entry.rmin + entry.rmax;
});
auto end = begin + entries.size();
auto i = thrust::upper_bound(thrust::seq, begin + 1, end - 1, rank) - begin - 1;
if (rank < entries[i].RMinNext() + entries[i+1].RMaxPrev()) {
return entries[i];
} else {
return entries[i+1];
}
}
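
For illustration only, the same rank query can be written as a host-side analog with the standard library. This is a minimal sketch: `Entry` below is a simplified stand-in for WQSketch::Entry carrying just the fields used here, and the input is assumed to hold at least two entries.

#include <algorithm>
#include <cassert>
#include <vector>

struct Entry {  // simplified stand-in for WQSketch::Entry
  float rmin, rmax, wmin, value;
  float RMinNext() const { return rmin + wmin; }
  float RMaxPrev() const { return rmax - wmin; }
};

Entry BinarySearchQueryHost(std::vector<Entry> const& entries, float rank) {
  assert(entries.size() >= 2);
  rank *= 2;
  if (rank < entries.front().rmin + entries.front().rmax) { return entries.front(); }
  if (rank >= entries.back().rmin + entries.back().rmax) { return entries.back(); }
  // Find the first interior entry whose rmin + rmax exceeds the doubled rank.
  auto it = std::upper_bound(
      entries.begin() + 1, entries.end() - 1, rank,
      [](float r, Entry const& e) { return r < e.rmin + e.rmax; });
  size_t i = static_cast<size_t>(it - entries.begin()) - 1;
  return rank < entries[i].RMinNext() + entries[i + 1].RMaxPrev() ? entries[i]
                                                                  : entries[i + 1];
}
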
template <typename T>
void CopyTo(Span<T> out, Span<T const> src) {
CHECK_EQ(out.size(), src.size());
dh::safe_cuda(cudaMemcpyAsync(out.data(), src.data(),
out.size_bytes(),
cudaMemcpyDefault));
}
// Compute the merge path.
common::Span<thrust::tuple<uint64_t, uint64_t>> MergePath(
Span<SketchEntry const> const &d_x, Span<bst_row_t const> const &x_ptr,
Span<SketchEntry const> const &d_y, Span<bst_row_t const> const &y_ptr,
Span<SketchEntry> out, Span<bst_row_t> out_ptr) {
auto x_merge_key_it = thrust::make_zip_iterator(thrust::make_tuple(
dh::MakeTransformIterator<bst_row_t>(
thrust::make_counting_iterator(0ul),
[=] __device__(size_t idx) { return dh::SegmentId(x_ptr, idx); }),
d_x.data()));
auto y_merge_key_it = thrust::make_zip_iterator(thrust::make_tuple(
dh::MakeTransformIterator<bst_row_t>(
thrust::make_counting_iterator(0ul),
[=] __device__(size_t idx) { return dh::SegmentId(y_ptr, idx); }),
d_y.data()));
using Tuple = thrust::tuple<uint64_t, uint64_t>;
thrust::constant_iterator<uint64_t> a_ind_iter(0ul);
thrust::constant_iterator<uint64_t> b_ind_iter(1ul);
auto place_holder = thrust::make_constant_iterator<uint64_t>(0u);
auto x_merge_val_it =
thrust::make_zip_iterator(thrust::make_tuple(a_ind_iter, place_holder));
auto y_merge_val_it =
thrust::make_zip_iterator(thrust::make_tuple(b_ind_iter, place_holder));
dh::XGBCachingDeviceAllocator<Tuple> alloc;
static_assert(sizeof(Tuple) == sizeof(SketchEntry), "");
// We reuse the memory for storing merge path.
common::Span<Tuple> merge_path{reinterpret_cast<Tuple *>(out.data()), out.size()};
// Determine the merge path, 0 if element is from x, 1 if it's from y.
thrust::merge_by_key(
thrust::cuda::par(alloc), x_merge_key_it, x_merge_key_it + d_x.size(),
y_merge_key_it, y_merge_key_it + d_y.size(), x_merge_val_it,
y_merge_val_it, thrust::make_discard_iterator(), merge_path.data(),
[=] __device__(auto const &l, auto const &r) -> bool {
auto l_column_id = thrust::get<0>(l);
auto r_column_id = thrust::get<0>(r);
if (l_column_id == r_column_id) {
return thrust::get<1>(l).value < thrust::get<1>(r).value;
}
return l_column_id < r_column_id;
});
// Compute output ptr
auto transform_it =
thrust::make_zip_iterator(thrust::make_tuple(x_ptr.data(), y_ptr.data()));
thrust::transform(
thrust::cuda::par(alloc), transform_it, transform_it + x_ptr.size(),
out_ptr.data(),
[] __device__(auto const& t) { return thrust::get<0>(t) + thrust::get<1>(t); });
// 0^th is the indicator, 1^st is a placeholder.
auto get_ind = []XGBOOST_DEVICE(Tuple const& t) { return thrust::get<0>(t); };
// 0^th is the counter for x, 1^st for y.
auto get_x = []XGBOOST_DEVICE(Tuple const &t) { return thrust::get<0>(t); };
auto get_y = []XGBOOST_DEVICE(Tuple const &t) { return thrust::get<1>(t); };
auto scan_key_it = dh::MakeTransformIterator<size_t>(
thrust::make_counting_iterator(0ul),
[=] __device__(size_t idx) { return dh::SegmentId(out_ptr, idx); });
auto scan_val_it = dh::MakeTransformIterator<Tuple>(
merge_path.data(), [=] __device__(Tuple const &t) -> Tuple {
auto ind = get_ind(t); // == 0 if element is from x
// x_counter, y_counter
return thrust::make_tuple<uint64_t, uint64_t>(!ind, ind);
});
// Compute the index for both x and y (which elements of x and y are used in each
// comparison) by scanning the binary merge path. Take the output [(x_0, y_0), (x_0, y_1),
// ...] as an example: the comparison between (x_0, y_0) adds 1 step to the merge path.
// Assuming y_0 is less than x_0, this step moves toward the end of y. After the
// comparison, the index of y is incremented by 1 from y_0 to y_1, and at the same time
// y_0 lands in the output as the first element of the merge result. The scan result is
// the subscript of x and y.
thrust::exclusive_scan_by_key(
thrust::cuda::par(alloc), scan_key_it, scan_key_it + merge_path.size(),
scan_val_it, merge_path.data(),
thrust::make_tuple<uint64_t, uint64_t>(0ul, 0ul),
thrust::equal_to<size_t>{},
[=] __device__(Tuple const &l, Tuple const &r) -> Tuple {
return thrust::make_tuple(get_x(l) + get_x(r), get_y(l) + get_y(r));
});
return merge_path;
}
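
The two-pass idea can be illustrated on the host for a single column. The following standalone sketch uses made-up values: pass one records whether each output slot is filled from x or from y, and an exclusive scan of (!ind, ind) then recovers which subscript of x and of y was being compared when that slot was produced.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> x{1.f, 4.f, 6.f};
  std::vector<float> y{2.f, 3.f, 5.f};
  // Pass 1: the binary merge path, 0 if the slot comes from x, 1 if from y.
  std::vector<uint64_t> path;
  for (size_t i = 0, j = 0; i < x.size() || j < y.size();) {
    bool from_y = (i == x.size()) || (j < y.size() && y[j] < x[i]);
    path.push_back(from_y ? 1u : 0u);
    from_y ? ++j : ++i;
  }
  // Pass 2: exclusive scan of (!ind, ind) yields the (a_ind, b_ind) pairs the
  // merge kernel reads back, i.e. the subscripts of x and y at each output slot.
  uint64_t a_ind = 0, b_ind = 0;
  for (uint64_t ind : path) {
    std::printf("out slot from %s: a_ind=%llu b_ind=%llu\n", ind ? "y" : "x",
                static_cast<unsigned long long>(a_ind),
                static_cast<unsigned long long>(b_ind));
    a_ind += !ind;
    b_ind += ind;
  }
  return 0;
}
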
// Merge d_x and d_y into out. By the definition of merged rank, each output element
// depends on a predicate (which summary the element comes from), so we run the merge
// in 2 passes: first obtain the merge path, then customize the standard merge
// algorithm using it.
void MergeImpl(int32_t device, Span<SketchEntry const> const &d_x,
Span<bst_row_t const> const &x_ptr,
Span<SketchEntry const> const &d_y,
Span<bst_row_t const> const &y_ptr,
Span<SketchEntry> out,
Span<bst_row_t> out_ptr) {
dh::safe_cuda(cudaSetDevice(device));
CHECK_EQ(d_x.size() + d_y.size(), out.size());
CHECK_EQ(x_ptr.size(), out_ptr.size());
CHECK_EQ(y_ptr.size(), out_ptr.size());
auto d_merge_path = MergePath(d_x, x_ptr, d_y, y_ptr, out, out_ptr);
auto d_out = out;
dh::LaunchN(device, d_out.size(), [=] __device__(size_t idx) {
auto column_id = dh::SegmentId(out_ptr, idx);
idx -= out_ptr[column_id];
auto d_x_column =
d_x.subspan(x_ptr[column_id], x_ptr[column_id + 1] - x_ptr[column_id]);
auto d_y_column =
d_y.subspan(y_ptr[column_id], y_ptr[column_id + 1] - y_ptr[column_id]);
auto d_out_column = d_out.subspan(
out_ptr[column_id], out_ptr[column_id + 1] - out_ptr[column_id]);
auto d_path_column = d_merge_path.subspan(
out_ptr[column_id], out_ptr[column_id + 1] - out_ptr[column_id]);
uint64_t a_ind, b_ind;
thrust::tie(a_ind, b_ind) = d_path_column[idx];
// Handle empty column. If both columns are empty, we should not get this column_id
// as result of binary search.
assert((d_x_column.size() != 0) || (d_y_column.size() != 0));
if (d_x_column.size() == 0) {
d_out_column[idx] = d_y_column[b_ind];
return;
}
if (d_y_column.size() == 0) {
d_out_column[idx] = d_x_column[a_ind];
return;
}
// Handle trailing elements.
assert(a_ind <= d_x_column.size());
if (a_ind == d_x_column.size()) {
// Trailing elements are from y because there's no more x to land.
auto y_elem = d_y_column[b_ind];
d_out_column[idx] = SketchEntry(y_elem.rmin + d_x_column.back().RMinNext(),
y_elem.rmax + d_x_column.back().rmax,
y_elem.wmin, y_elem.value);
return;
}
auto x_elem = d_x_column[a_ind];
assert(b_ind <= d_y_column.size());
if (b_ind == d_y_column.size()) {
d_out_column[idx] = SketchEntry(x_elem.rmin + d_y_column.back().RMinNext(),
x_elem.rmax + d_y_column.back().rmax,
x_elem.wmin, x_elem.value);
return;
}
auto y_elem = d_y_column[b_ind];
/* Merge procedure. See A.3 merge operation eq (26) ~ (28). The trick to interpreting
   them is to rewrite the symbols on both sides of the equality. Take eq (26) as an
   example: expand it according to the definition of extended rank, then rewrite it as:

   If $k_i$ is the $i$-th element in the output and \textbf{comes from $D_1$}:

     r_\bar{D}(k_i) = r_{\bar{D_1}}(k_i) + w_{\bar{D_1}}(k_i) +
                      [r_{\bar{D_2}}(x_i) + w_{\bar{D_2}}(x_i)]

   where $x_i$ is the largest element in $D_2$ that is less than $k_i$. Evaluating the
   rank of $k_i$ in $D_1$ is valid since $k_i \in D_1$. The other 2 equations can be
   applied similarly when $k_i$ comes from a different $D$; just use the symbols of the
   corresponding source summary.
*/
assert(idx < d_out_column.size());
if (x_elem.value == y_elem.value) {
d_out_column[idx] =
SketchEntry{x_elem.rmin + y_elem.rmin, x_elem.rmax + y_elem.rmax,
x_elem.wmin + y_elem.wmin, x_elem.value};
} else if (x_elem.value < y_elem.value) {
// elem from x is landed. yprev_min is the element in D_2 that's 1 rank less than
// x_elem if we put x_elem in D_2.
float yprev_min = b_ind == 0 ? 0.0f : d_y_column[b_ind - 1].RMinNext();
// rmin should be equal to x_elem.rmin + x_elem.wmin + yprev_min. But in the
// implementation, the weight is stored in a separate field and we compute the
// extended definition on the fly when needed.
d_out_column[idx] =
SketchEntry{x_elem.rmin + yprev_min, x_elem.rmax + y_elem.RMaxPrev(),
x_elem.wmin, x_elem.value};
} else {
// elem from y is landed.
float xprev_min = a_ind == 0 ? 0.0f : d_x_column[a_ind - 1].RMinNext();
d_out_column[idx] =
SketchEntry{xprev_min + y_elem.rmin, x_elem.RMaxPrev() + y_elem.rmax,
y_elem.wmin, y_elem.value};
}
});
}
void SketchContainer::Push(common::Span<OffsetT const> cuts_ptr,
dh::caching_device_vector<SketchEntry>* entries) {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_));
// Copy or merge the new cuts; pruning is performed during `MakeCuts`.
if (this->Current().size() == 0) {
CHECK_EQ(this->columns_ptr_.Size(), cuts_ptr.size());
// See thrust issue 1030, THRUST_CPP_DIALECT is not correctly defined so
// move constructor is not used.
this->Current().swap(*entries);
CHECK_EQ(entries->size(), 0);
auto d_cuts_ptr = this->columns_ptr_.DevicePointer();
thrust::copy(thrust::device, cuts_ptr.data(),
cuts_ptr.data() + cuts_ptr.size(), d_cuts_ptr);
} else {
auto d_entries = dh::ToSpan(*entries);
this->Merge(cuts_ptr, d_entries);
this->FixError();
}
CHECK_NE(this->columns_ptr_.Size(), 0);
timer_.Stop(__func__);
}
size_t SketchContainer::Unique() {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_));
this->columns_ptr_.SetDevice(device_);
Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
Span<SketchEntry> entries = dh::ToSpan(this->Current());
HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
scan_out.SetDevice(device_);
auto d_scan_out = scan_out.DeviceSpan();
d_column_scan = this->columns_ptr_.DeviceSpan();
size_t n_uniques = dh::SegmentedUnique(
d_column_scan.data(), d_column_scan.data() + d_column_scan.size(),
entries.data(), entries.data() + entries.size(), scan_out.DevicePointer(),
entries.data(),
detail::SketchUnique{});
this->columns_ptr_.Copy(scan_out);
CHECK(!this->columns_ptr_.HostCanRead());
this->Current().resize(n_uniques);
timer_.Stop(__func__);
return n_uniques;
}
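
As a rough host-side analog (a sketch only; it keys on the raw value rather than detail::SketchUnique and ignores the device allocator), the segmented unique step amounts to one std::unique per column followed by rewriting the column pointers:

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Deduplicate adjacent values within each column segment and rewrite the
// column pointers; returns the total number of unique entries.
size_t SegmentedUniqueHost(std::vector<size_t>* columns_ptr, std::vector<float>* values) {
  std::vector<float> out;
  std::vector<size_t> new_ptr{0};
  auto const& ptr = *columns_ptr;
  for (size_t c = 0; c + 1 < ptr.size(); ++c) {
    std::vector<float> column(values->begin() + ptr[c], values->begin() + ptr[c + 1]);
    column.erase(std::unique(column.begin(), column.end()), column.end());
    out.insert(out.end(), column.begin(), column.end());
    new_ptr.push_back(out.size());
  }
  *values = std::move(out);
  *columns_ptr = std::move(new_ptr);
  return values->size();
}
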
void SketchContainer::Prune(size_t to) {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_));
this->Unique();
OffsetT to_total = 0;
HostDeviceVector<OffsetT> new_columns_ptr{to_total};
for (bst_feature_t i = 0; i < num_columns_; ++i) {
size_t length = this->Column(i).size();
length = std::min(length, to);
to_total += length;
new_columns_ptr.HostVector().emplace_back(to_total);
}
new_columns_ptr.SetDevice(device_);
this->Other().resize(to_total);
auto d_columns_ptr_in = this->columns_ptr_.ConstDeviceSpan();
auto d_columns_ptr_out = new_columns_ptr.ConstDeviceSpan();
auto out = dh::ToSpan(this->Other());
auto in = dh::ToSpan(this->Current());
dh::LaunchN(0, to_total, [=] __device__(size_t idx) {
size_t column_id = dh::SegmentId(d_columns_ptr_out, idx);
auto out_column = out.subspan(d_columns_ptr_out[column_id],
d_columns_ptr_out[column_id + 1] -
d_columns_ptr_out[column_id]);
auto in_column = in.subspan(d_columns_ptr_in[column_id],
d_columns_ptr_in[column_id + 1] -
d_columns_ptr_in[column_id]);
idx -= d_columns_ptr_out[column_id];
// Input column has fewer entries than `to`; just copy them to the output. This is
// correct as the new output size is calculated based on both `to` and the current
// column size.
if (in_column.size() <= to) {
out_column[idx] = in_column[idx];
return;
}
// 1 thread for each output. See A.4 for detail.
auto entries = in_column;
auto d_out = out_column;
if (idx == 0) {
d_out.front() = entries.front();
return;
}
if (idx == to - 1) {
d_out.back() = entries.back();
return;
}
float w = entries.back().rmin - entries.front().rmax;
assert(w != 0);
auto budget = static_cast<float>(d_out.size());
assert(budget != 0);
auto q = ((idx * w) / (to - 1) + entries.front().rmax);
d_out[idx] = BinarySearchQuery(entries, q);
});
this->columns_ptr_.HostVector() = new_columns_ptr.HostVector();
this->Alternate();
timer_.Stop(__func__);
}
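
The query rank used by the kernel above is evenly spaced over the interior of the column, while the first and last entries are copied verbatim. Below is a small host-side sketch of that rank schedule; the helper name is hypothetical, but the arithmetic matches the kernel.

#include <cstddef>
#include <vector>

// Ranks queried by the pruning kernel for one column: w spans from the first
// entry's rmax to the last entry's rmin, and interior slots query
// idx * w / (to - 1) + front_rmax. Slots 0 and to - 1 copy the end points.
std::vector<float> PruneQueryRanks(float front_rmax, float back_rmin, size_t to) {
  std::vector<float> ranks(to, 0.0f);
  float w = back_rmin - front_rmax;
  for (size_t idx = 1; idx + 1 < to; ++idx) {
    ranks[idx] = idx * w / (to - 1) + front_rmax;
  }
  return ranks;
}
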
void SketchContainer::Merge(Span<OffsetT const> d_that_columns_ptr,
Span<SketchEntry const> that) {
dh::safe_cuda(cudaSetDevice(device_));
timer_.Start(__func__);
if (this->Current().size() == 0) {
CHECK_EQ(this->columns_ptr_.HostVector().back(), 0);
CHECK_EQ(this->columns_ptr_.HostVector().size(), d_that_columns_ptr.size());
CHECK_EQ(columns_ptr_.Size(), num_columns_ + 1);
thrust::copy(thrust::device, d_that_columns_ptr.data(),
d_that_columns_ptr.data() + d_that_columns_ptr.size(),
this->columns_ptr_.DevicePointer());
auto total = this->columns_ptr_.HostVector().back();
this->Current().resize(total);
CopyTo(dh::ToSpan(this->Current()), that);
timer_.Stop(__func__);
return;
}
this->Other().resize(this->Current().size() + that.size());
CHECK_EQ(d_that_columns_ptr.size(), this->columns_ptr_.Size());
HostDeviceVector<OffsetT> new_columns_ptr;
new_columns_ptr.SetDevice(device_);
new_columns_ptr.Resize(this->ColumnsPtr().size());
MergeImpl(device_, this->Data(), this->ColumnsPtr(),
that, d_that_columns_ptr,
dh::ToSpan(this->Other()), new_columns_ptr.DeviceSpan());
this->columns_ptr_ = std::move(new_columns_ptr);
CHECK_EQ(this->columns_ptr_.Size(), num_columns_ + 1);
CHECK_EQ(new_columns_ptr.Size(), 0);
this->Alternate();
timer_.Stop(__func__);
}
void SketchContainer::FixError() {
dh::safe_cuda(cudaSetDevice(device_));
auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
auto in = dh::ToSpan(this->Current());
dh::LaunchN(device_, in.size(), [=] __device__(size_t idx) {
auto column_id = dh::SegmentId(d_columns_ptr, idx);
auto in_column = in.subspan(d_columns_ptr[column_id],
d_columns_ptr[column_id + 1] -
d_columns_ptr[column_id]);
idx -= d_columns_ptr[column_id];
float prev_rmin = idx == 0 ? 0.0f : in_column[idx-1].rmin;
if (in_column[idx].rmin < prev_rmin) {
in_column[idx].rmin = prev_rmin;
}
float prev_rmax = idx == 0 ? 0.0f : in_column[idx-1].rmax;
if (in_column[idx].rmax < prev_rmax) {
in_column[idx].rmax = prev_rmax;
}
float rmin_next = in_column[idx].RMinNext();
if (in_column[idx].rmax < rmin_next) {
in_column[idx].rmax = rmin_next;
}
});
}
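
A sequential host-side analog of the repair above, as a sketch only: the device kernel compares each entry against its unmodified neighbour in parallel, so results can differ slightly when errors are large, but the clamping rule is the same.

#include <vector>

struct EntryFix {  // simplified stand-in with the fields used below
  float rmin, rmax, wmin, value;
  float RMinNext() const { return rmin + wmin; }
};

// Clamp rmin/rmax so that both are non-decreasing along the column and
// rmax stays at least rmin + wmin.
void FixErrorHost(std::vector<EntryFix>* column) {
  float prev_rmin = 0.0f, prev_rmax = 0.0f;
  for (auto& e : *column) {
    if (e.rmin < prev_rmin) { e.rmin = prev_rmin; }
    if (e.rmax < prev_rmax) { e.rmax = prev_rmax; }
    if (e.rmax < e.RMinNext()) { e.rmax = e.RMinNext(); }
    prev_rmin = e.rmin;
    prev_rmax = e.rmax;
  }
}
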
void SketchContainer::AllReduce() {
dh::safe_cuda(cudaSetDevice(device_));
auto world = rabit::GetWorldSize();
if (world == 1) {
return;
}
timer_.Start(__func__);
if (!reducer_) {
reducer_ = std::make_unique<dh::AllReducer>();
reducer_->Init(device_);
}
// Reduce the overhead on syncing.
size_t global_sum_rows = num_rows_;
rabit::Allreduce<rabit::op::Sum>(&global_sum_rows, 1);
size_t intermediate_num_cuts =
std::min(global_sum_rows, static_cast<size_t>(num_bins_ * kFactor));
this->Prune(intermediate_num_cuts);
auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
CHECK_EQ(d_columns_ptr.size(), num_columns_ + 1);
size_t n = d_columns_ptr.size();
rabit::Allreduce<rabit::op::Max>(&n, 1);
CHECK_EQ(n, d_columns_ptr.size()) << "Number of columns differs across workers";
// Get the columns ptr from all workers
dh::device_vector<SketchContainer::OffsetT> gathered_ptrs;
gathered_ptrs.resize(d_columns_ptr.size() * world, 0);
size_t rank = rabit::GetRank();
auto offset = rank * d_columns_ptr.size();
thrust::copy(thrust::device, d_columns_ptr.data(), d_columns_ptr.data() + d_columns_ptr.size(),
gathered_ptrs.begin() + offset);
reducer_->AllReduceSum(gathered_ptrs.data().get(), gathered_ptrs.data().get(),
gathered_ptrs.size());
// Get the data from all workers.
std::vector<size_t> recv_lengths;
dh::caching_device_vector<char> recvbuf;
reducer_->AllGather(this->Current().data().get(),
dh::ToSpan(this->Current()).size_bytes(), &recv_lengths,
&recvbuf);
reducer_->Synchronize();
// Segment the received data.
auto s_recvbuf = dh::ToSpan(recvbuf);
std::vector<Span<SketchEntry>> allworkers;
offset = 0;
for (int32_t i = 0; i < world; ++i) {
size_t length_as_bytes = recv_lengths.at(i);
auto raw = s_recvbuf.subspan(offset, length_as_bytes);
auto sketch = Span<SketchEntry>(reinterpret_cast<SketchEntry *>(raw.data()),
length_as_bytes / sizeof(SketchEntry));
allworkers.emplace_back(sketch);
offset += length_as_bytes;
}
// Merge them into a new sketch.
SketchContainer new_sketch(num_bins_, this->num_columns_, global_sum_rows,
this->device_);
for (size_t i = 0; i < allworkers.size(); ++i) {
auto worker = allworkers[i];
auto worker_ptr =
dh::ToSpan(gathered_ptrs)
.subspan(i * d_columns_ptr.size(), d_columns_ptr.size());
new_sketch.Merge(worker_ptr, worker);
new_sketch.FixError();
}
*this = std::move(new_sketch);
timer_.Stop(__func__);
}
void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_));
p_cuts->min_vals_.Resize(num_columns_);
// Sync between workers.
this->AllReduce();
// Prune to final number of bins.
this->Prune(num_bins_ + 1);
this->Unique();
this->FixError();
// Set up inputs
auto d_in_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
p_cuts->min_vals_.SetDevice(device_);
auto d_min_values = p_cuts->min_vals_.DeviceSpan();
auto in_cut_values = dh::ToSpan(this->Current());
// Set up output ptr
p_cuts->cut_ptrs_.SetDevice(device_);
auto& h_out_columns_ptr = p_cuts->cut_ptrs_.HostVector();
h_out_columns_ptr.clear();
h_out_columns_ptr.push_back(0);
for (bst_feature_t i = 0; i < num_columns_; ++i) {
h_out_columns_ptr.push_back(
std::min(static_cast<size_t>(std::max(static_cast<size_t>(1ul),
this->Column(i).size())),
static_cast<size_t>(num_bins_)));
}
std::partial_sum(h_out_columns_ptr.begin(), h_out_columns_ptr.end(),
h_out_columns_ptr.begin());
auto d_out_columns_ptr = p_cuts->cut_ptrs_.ConstDeviceSpan();
// Set up output cuts
size_t total_bins = h_out_columns_ptr.back();
p_cuts->cut_values_.SetDevice(device_);
p_cuts->cut_values_.Resize(total_bins);
auto out_cut_values = p_cuts->cut_values_.DeviceSpan();
dh::LaunchN(0, total_bins, [=] __device__(size_t idx) {
auto column_id = dh::SegmentId(d_out_columns_ptr, idx);
auto in_column = in_cut_values.subspan(d_in_columns_ptr[column_id],
d_in_columns_ptr[column_id + 1] -
d_in_columns_ptr[column_id]);
auto out_column = out_cut_values.subspan(d_out_columns_ptr[column_id],
d_out_columns_ptr[column_id + 1] -
d_out_columns_ptr[column_id]);
idx -= d_out_columns_ptr[column_id];
if (in_column.size() == 0) {
// If the column is empty, we push a dummy value. It won't affect training: since the
// column is empty, trees cannot split on it. This is just to be consistent with the
// rest of the library.
if (idx == 0) {
d_min_values[column_id] = kRtEps;
out_column[0] = kRtEps;
assert(out_column.size() == 1);
}
return;
}
// First thread is responsible for setting min values.
if (idx == 0) {
auto mval = in_column[idx].value;
d_min_values[column_id] = mval - (fabs(mval) + 1e-5);
}
// Last thread is responsible for setting a value that's greater than other cuts.
if (idx == out_column.size() - 1) {
const bst_float cpt = in_column.back().value;
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5);
out_column[idx] = last;
return;
}
assert(idx+1 < in_column.size());
out_column[idx] = in_column[idx+1].value;
});
timer_.Stop(__func__);
}
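
The per-column output sizing above (at least one cut for an empty column, at most num_bins, then a prefix sum into CSC-style pointers) can be summarised by a small host-side helper; the name is hypothetical, but the arithmetic mirrors the loop over h_out_columns_ptr.

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// CSC-style output pointers for the cut values: every column gets at least 1
// slot (the dummy value of an empty column) and at most num_bins slots.
std::vector<size_t> MakeOutputPtrs(std::vector<size_t> const& column_sizes,
                                   size_t num_bins) {
  std::vector<size_t> out_ptr{0};
  for (size_t n : column_sizes) {
    out_ptr.push_back(std::min(std::max<size_t>(1, n), num_bins));
  }
  std::partial_sum(out_ptr.begin(), out_ptr.end(), out_ptr.begin());
  return out_ptr;
}
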
} // namespace common
} // namespace xgboost

141
src/common/quantile.cuh Normal file

@ -0,0 +1,141 @@
#ifndef XGBOOST_COMMON_QUANTILE_CUH_
#define XGBOOST_COMMON_QUANTILE_CUH_
#include <memory>
#include "xgboost/span.h"
#include "device_helpers.cuh"
#include "quantile.h"
#include "timer.h"
namespace xgboost {
namespace common {
class HistogramCuts;
using WQSketch = WQuantileSketch<bst_float, bst_float>;
using SketchEntry = WQSketch::Entry;
/*!
* \brief A container that holds the device sketches. Sketching is performed per-column,
* but fused into a single operation for performance.
*/
class SketchContainer {
public:
static constexpr float kFactor = WQSketch::kFactor;
using OffsetT = bst_row_t;
static_assert(sizeof(OffsetT) == sizeof(size_t), "Wrong type for sketch element offset.");
private:
Monitor timer_;
std::unique_ptr<dh::AllReducer> reducer_;
bst_row_t num_rows_;
bst_feature_t num_columns_;
int32_t num_bins_;
int32_t device_;
// Double buffer, as neither prune nor merge can be performed in place.
dh::caching_device_vector<SketchEntry> entries_a_;
dh::caching_device_vector<SketchEntry> entries_b_;
bool current_buffer_ {true};
// The container is just a CSC matrix.
HostDeviceVector<OffsetT> columns_ptr_;
dh::caching_device_vector<SketchEntry>& Current() {
if (current_buffer_) {
return entries_a_;
} else {
return entries_b_;
}
}
dh::caching_device_vector<SketchEntry>& Other() {
if (!current_buffer_) {
return entries_a_;
} else {
return entries_b_;
}
}
dh::caching_device_vector<SketchEntry> const& Current() const {
return const_cast<SketchContainer*>(this)->Current();
}
dh::caching_device_vector<SketchEntry> const& Other() const {
return const_cast<SketchContainer*>(this)->Other();
}
void Alternate() {
current_buffer_ = !current_buffer_;
}
// Get the span of one column.
Span<SketchEntry> Column(bst_feature_t i) {
auto data = dh::ToSpan(this->Current());
auto h_ptr = columns_ptr_.ConstHostSpan();
auto c = data.subspan(h_ptr[i], h_ptr[i+1] - h_ptr[i]);
return c;
}
public:
/* \brief GPU quantile structure, with sketch data for each column.
*
* \param max_bin Maximum number of bins per column.
* \param num_columns Total number of columns in the dataset.
* \param num_rows Total number of rows in the known dataset (typically the rows in the current worker).
* \param device GPU ID.
*/
SketchContainer(int32_t max_bin, bst_feature_t num_columns, bst_row_t num_rows, int32_t device) :
num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin}, device_{device} {
// Initialize Sketches for this dmatrix
this->columns_ptr_.SetDevice(device_);
this->columns_ptr_.Resize(num_columns + 1);
timer_.Init(__func__);
}
/* \brief Return GPU ID for this container. */
int32_t DeviceIdx() const { return device_; }
/* \brief Removes all the duplicated elements in the quantile structure. */
size_t Unique();
/* Fix rounding error and re-establish the invariant. The error is mostly generated by the
* addition inside `RMinNext` and the subtraction in `RMaxPrev`. */
void FixError();
/* \brief Push a CSC structured cut matrix. */
void Push(common::Span<OffsetT const> cuts_ptr,
dh::caching_device_vector<SketchEntry>* entries);
/* \brief Prune the quantile structure.
*
* \param to The maximum size of the pruned quantile structure. If its size is
* already less than `to`, then no operation is performed.
*/
void Prune(size_t to);
/* \brief Merge another set of sketches.
* \param that Columns of the other sketch.
*/
void Merge(Span<OffsetT const> that_columns_ptr,
Span<SketchEntry const> that);
/* \brief Merge quantiles from other GPU workers. */
void AllReduce();
/* \brief Create the final histogram cut values. */
void MakeCuts(HistogramCuts* cuts);
Span<SketchEntry const> Data() const {
return {this->Current().data().get(), this->Current().size()};
}
Span<OffsetT const> ColumnsPtr() const { return this->columns_ptr_.ConstDeviceSpan(); }
SketchContainer(SketchContainer&&) = default;
SketchContainer& operator=(SketchContainer&&) = default;
SketchContainer(const SketchContainer&) = delete;
SketchContainer& operator=(const SketchContainer&) = delete;
};
namespace detail {
struct SketchUnique {
XGBOOST_DEVICE bool operator()(SketchEntry const& a, SketchEntry const& b) const {
return a.value - b.value == 0;
}
};
} // namespace detail
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_QUANTILE_CUH_
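
Putting the pieces together, a typical call sequence looks like the test helper ValidateBatchedCuts further down in this diff. The sketch below assumes the XGBoost build environment; the include paths are illustrative placeholders that depend on where the file lives, and the adapter type is left as a template parameter.

#include <limits>

#include "../../src/common/hist_util.cuh"  // AdapterDeviceSketch (path is illustrative)
#include "../../src/common/quantile.cuh"   // SketchContainer (path is illustrative)

namespace xgboost {
namespace common {

// Build histogram cuts from a device adapter batch, mirroring the flow used in
// the tests below.
template <typename Adapter>
void BuildCutsFromAdapter(Adapter* adapter, int num_bins, bst_feature_t num_columns,
                          bst_row_t num_rows, HistogramCuts* out_cuts) {
  SketchContainer sketch(num_bins, num_columns, num_rows, /*device=*/0);
  AdapterDeviceSketch(adapter->Value(), num_bins,
                      std::numeric_limits<float>::quiet_NaN(), &sketch);
  // Synchronises workers, prunes to num_bins and writes the final cut values.
  sketch.MakeCuts(out_cuts);
}

}  // namespace common
}  // namespace xgboost
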


@ -55,6 +55,14 @@ struct WQSummary {
XGBOOST_DEVICE inline RType RMaxPrev() const { XGBOOST_DEVICE inline RType RMaxPrev() const {
return rmax - wmin; return rmax - wmin;
} }
friend std::ostream& operator<<(std::ostream& os, Entry const& e) {
os << "rmin: " << e.rmin << ", "
<< "rmax: " << e.rmax << ", "
<< "wmin: " << e.wmin << ", "
<< "value: " << e.value;
return os;
}
}; };
/*! \brief input data queue before entering the summary */ /*! \brief input data queue before entering the summary */
struct Queue { struct Queue {
@ -184,14 +192,14 @@ struct WQSummary {
} }
} }
} }
/*! /*!
* \brief set current summary to be pruned summary of src * \brief set current summary to be pruned summary of src
* assume data field is already allocated to be at least maxsize * assume data field is already allocated to be at least maxsize
* \param src source summary * \param src source summary
* \param maxsize size we can afford in the pruned sketch * \param maxsize size we can afford in the pruned sketch
*/ */
void SetPrune(const WQSummary &src, size_t maxsize) {
inline void SetPrune(const WQSummary &src, size_t maxsize) {
if (src.size <= maxsize) { if (src.size <= maxsize) {
this->CopyFrom(src); return; this->CopyFrom(src); return;
} }
@ -454,6 +462,9 @@ struct WXQSummary : public WQSummary<DType, RType> {
*/ */
template<typename DType, typename RType, class TSummary> template<typename DType, typename RType, class TSummary>
class QuantileSketchTemplate { class QuantileSketchTemplate {
public:
static float constexpr kFactor = 8.0;
public: public:
/*! \brief type of summary type */ /*! \brief type of summary type */
using Summary = TSummary; using Summary = TSummary;

0
src/common/threading_utils.h Executable file → Normal file


@ -57,80 +57,52 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin
size_t nnz = 0; size_t nnz = 0;
// Sketch for all batches. // Sketch for all batches.
iter.Reset(); iter.Reset();
common::HistogramCuts cuts;
common::DenseCuts dense_cuts(&cuts);
std::vector<common::SketchContainer> sketch_containers; std::vector<common::SketchContainer> sketch_containers;
size_t batches = 0; size_t batches = 0;
size_t accumulated_rows = 0; size_t accumulated_rows = 0;
bst_feature_t cols = 0; bst_feature_t cols = 0;
int32_t device = -1;
while (iter.Next()) { while (iter.Next()) {
auto device = proxy->DeviceIdx(); device = proxy->DeviceIdx();
dh::safe_cuda(cudaSetDevice(device)); dh::safe_cuda(cudaSetDevice(device));
if (cols == 0) { if (cols == 0) {
cols = num_cols(); cols = num_cols();
} else { } else {
CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns."; CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns.";
} }
sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows()); sketch_containers.emplace_back(batch_param_.max_bin, num_cols(), num_rows(), device);
auto* p_sketch = &sketch_containers.back(); auto* p_sketch = &sketch_containers.back();
if (proxy->Info().weights_.Size() != 0) {
proxy->Info().weights_.SetDevice(device); proxy->Info().weights_.SetDevice(device);
Dispatch(proxy, [&](auto const &value) { Dispatch(proxy, [&](auto const &value) {
common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin, common::AdapterDeviceSketchWeighted(value, batch_param_.max_bin,
proxy->Info(), proxy->Info(), missing, p_sketch);
missing, device, p_sketch);
}); });
} else {
Dispatch(proxy, [&](auto const &value) {
common::AdapterDeviceSketch(value, batch_param_.max_bin, missing,
device, p_sketch);
});
}
auto batch_rows = num_rows(); auto batch_rows = num_rows();
accumulated_rows += batch_rows; accumulated_rows += batch_rows;
dh::caching_device_vector<size_t> row_counts(batch_rows + 1, 0); dh::caching_device_vector<size_t> row_counts(batch_rows + 1, 0);
common::Span<size_t> row_counts_span(row_counts.data().get(), common::Span<size_t> row_counts_span(row_counts.data().get(),
row_counts.size()); row_counts.size());
row_stride = row_stride = std::max(row_stride, Dispatch(proxy, [=](auto const &value) {
std::max(row_stride, Dispatch(proxy, [=](auto const& value) { return GetRowCounts(value, row_counts_span,
return GetRowCounts(value, row_counts_span, device, missing); device, missing);
})); }));
nnz += thrust::reduce(thrust::cuda::par(alloc), nnz += thrust::reduce(thrust::cuda::par(alloc), row_counts.begin(),
row_counts.begin(), row_counts.end()); row_counts.end());
batches++; batches++;
} }
// Merging multiple batches for each column common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, device);
std::vector<common::WQSketch::SummaryContainer> summary_array(cols); for (auto const& sketch : sketch_containers) {
size_t intermediate_num_cuts = std::min( final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data());
accumulated_rows, static_cast<size_t>(batch_param_.max_bin * final_sketch.FixError();
common::SketchContainer::kFactor));
size_t nbytes =
common::WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts);
#pragma omp parallel for num_threads(nthread) if (nthread > 0)
for (omp_ulong c = 0; c < cols; ++c) {
for (auto& sketch_batch : sketch_containers) {
common::WQSketch::SummaryContainer summary;
sketch_batch.sketches_.at(c).GetSummary(&summary);
sketch_batch.sketches_.at(c).Init(0, 1);
summary_array.at(c).Reduce(summary, nbytes);
}
} }
sketch_containers.clear(); sketch_containers.clear();
sketch_containers.shrink_to_fit();
// Build the final summary. common::HistogramCuts cuts;
std::vector<common::WQSketch> sketches(cols); final_sketch.MakeCuts(&cuts);
#pragma omp parallel for num_threads(nthread) if (nthread > 0)
for (omp_ulong c = 0; c < cols; ++c) {
sketches.at(c).Init(
accumulated_rows,
1.0 / (common::SketchContainer::kFactor * batch_param_.max_bin));
sketches.at(c).PushSummary(summary_array.at(c));
}
dense_cuts.Init(&sketches, batch_param_.max_bin, accumulated_rows);
summary_array.clear();
this->info_.num_col_ = cols; this->info_.num_col_ = cols;
this->info_.num_row_ = accumulated_rows; this->info_.num_row_ = accumulated_rows;


@ -5,6 +5,7 @@
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <xgboost/base.h> #include <xgboost/base.h>
#include "../../../src/common/device_helpers.cuh" #include "../../../src/common/device_helpers.cuh"
#include "../../../src/common/quantile.h"
#include "../helpers.h" #include "../helpers.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
@ -14,3 +15,128 @@ TEST(SumReduce, Test) {
ASSERT_NEAR(sum, 100.0f, 1e-5); ASSERT_NEAR(sum, 100.0f, 1e-5);
} }
void TestAtomicSizeT() {
size_t constexpr kThreads = 235;
dh::device_vector<size_t> out(1, 0);
auto d_out = dh::ToSpan(out);
dh::LaunchN(0, kThreads, [=]__device__(size_t idx){
atomicAdd(&d_out[0], static_cast<size_t>(1));
});
ASSERT_EQ(out[0], kThreads);
}
TEST(AtomicAdd, SizeT) {
TestAtomicSizeT();
}
TEST(SegmentedUnique, Basic) {
std::vector<float> values{0.1f, 0.2f, 0.3f, 0.62448811531066895f, 0.62448811531066895f, 0.4f};
std::vector<size_t> segments{0, 3, 6};
thrust::device_vector<float> d_values(values);
thrust::device_vector<xgboost::bst_feature_t> d_segments{segments};
thrust::device_vector<xgboost::bst_feature_t> d_segs_out(d_segments.size());
thrust::device_vector<float> d_vals_out(d_values.size());
size_t n_uniques = dh::SegmentedUnique(
d_segments.data().get(), d_segments.data().get() + d_segments.size(),
d_values.data().get(), d_values.data().get() + d_values.size(),
d_segs_out.data().get(), d_vals_out.data().get(),
thrust::equal_to<float>{});
CHECK_EQ(n_uniques, 5);
std::vector<float> values_sol{0.1f, 0.2f, 0.3f, 0.62448811531066895f, 0.4f};
for (auto i = 0 ; i < values_sol.size(); i ++) {
ASSERT_EQ(d_vals_out[i], values_sol[i]);
}
std::vector<xgboost::bst_feature_t> segments_sol{0, 3, 5};
for (size_t i = 0; i < d_segments.size(); ++i) {
ASSERT_EQ(segments_sol[i], d_segs_out[i]);
}
d_segments[1] = 4;
d_segments[2] = 6;
n_uniques = dh::SegmentedUnique(
d_segments.data().get(), d_segments.data().get() + d_segments.size(),
d_values.data().get(), d_values.data().get() + d_values.size(),
d_segs_out.data().get(), d_vals_out.data().get(),
thrust::equal_to<float>{});
ASSERT_EQ(n_uniques, values.size());
for (auto i = 0 ; i < values.size(); i ++) {
ASSERT_EQ(d_vals_out[i], values[i]);
}
}
namespace {
using SketchEntry = xgboost::common::WQSummary<float, float>::Entry;
struct SketchUnique {
bool __device__ operator()(SketchEntry const& a, SketchEntry const& b) const {
return a.value - b.value == 0;
}
};
struct IsSorted {
bool __device__ operator()(SketchEntry const& a, SketchEntry const& b) const {
return a.value < b.value;
}
};
} // namespace
namespace xgboost {
namespace common {
void TestSegmentedUniqueRegression(std::vector<SketchEntry> values, size_t n_duplicated) {
std::vector<bst_feature_t> segments{0, static_cast<bst_feature_t>(values.size())};
thrust::device_vector<SketchEntry> d_values(values);
thrust::device_vector<bst_feature_t> d_segments(segments);
thrust::device_vector<bst_feature_t> d_segments_out(segments.size());
size_t n_uniques = dh::SegmentedUnique(
d_segments.data().get(), d_segments.data().get() + d_segments.size(), d_values.data().get(),
d_values.data().get() + d_values.size(), d_segments_out.data().get(), d_values.data().get(),
SketchUnique{});
ASSERT_EQ(n_uniques, values.size() - n_duplicated);
ASSERT_TRUE(thrust::is_sorted(thrust::device, d_values.begin(),
d_values.begin() + n_uniques, IsSorted{}));
ASSERT_EQ(segments.at(0), d_segments_out[0]);
ASSERT_EQ(segments.at(1), d_segments_out[1] + n_duplicated);
}
TEST(SegmentedUnique, Regression) {
{
std::vector<SketchEntry> values{{3149, 3150, 1, 0.62392902374267578},
{3151, 3152, 1, 0.62418866157531738},
{3152, 3153, 1, 0.62419462203979492},
{3153, 3154, 1, 0.62431186437606812},
{3154, 3155, 1, 0.6244881153106689453125},
{3155, 3156, 1, 0.6244881153106689453125},
{3155, 3156, 1, 0.6244881153106689453125},
{3155, 3156, 1, 0.6244881153106689453125},
{3157, 3158, 1, 0.62552797794342041},
{3158, 3159, 1, 0.6256556510925293},
{3159, 3160, 1, 0.62571090459823608},
{3160, 3161, 1, 0.62577134370803833}};
TestSegmentedUniqueRegression(values, 3);
}
{
std::vector<SketchEntry> values{{3149, 3150, 1, 0.62392902374267578},
{3151, 3152, 1, 0.62418866157531738},
{3152, 3153, 1, 0.62419462203979492},
{3153, 3154, 1, 0.62431186437606812},
{3154, 3155, 1, 0.6244881153106689453125},
{3157, 3158, 1, 0.62552797794342041},
{3158, 3159, 1, 0.6256556510925293},
{3159, 3160, 1, 0.62571090459823608},
{3160, 3161, 1, 0.62577134370803833}};
TestSegmentedUniqueRegression(values, 0);
}
{
std::vector<SketchEntry> values;
TestSegmentedUniqueRegression(values, 0);
}
}
} // namespace common
} // namespace xgboost


@ -30,11 +30,12 @@ HistogramCuts GetHostCuts(AdapterT *adapter, int num_bins, float missing) {
builder.Build(&dmat, num_bins); builder.Build(&dmat, num_bins);
return cuts; return cuts;
} }
TEST(HistUtil, DeviceSketch) { TEST(HistUtil, DeviceSketch) {
int num_rows = 5;
int num_columns = 1; int num_columns = 1;
int num_bins = 4; int num_bins = 4;
std::vector<float> x = {1.0, 2.0, 3.0, 4.0, 5.0}; std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 7.0f, -1.0f};
int num_rows = x.size();
auto dmat = GetDMatrixFromData(x, num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns);
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
@ -47,26 +48,6 @@ TEST(HistUtil, DeviceSketch) {
EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues()); EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues());
} }
// Duplicate this function from hist_util.cu so we don't have to expose it in
// header
size_t RequiredSampleCutsTest(int max_bins, size_t num_rows) {
double eps = 1.0 / (SketchContainer::kFactor * max_bins);
size_t dummy_nlevel;
size_t num_cuts;
WQuantileSketch<bst_float, bst_float>::LimitSizeLevel(
num_rows, eps, &dummy_nlevel, &num_cuts);
return std::min(num_cuts, num_rows);
}
size_t BytesRequiredForTest(size_t num_rows, size_t num_columns, size_t num_bins,
bool with_weights) {
size_t bytes_num_elements = BytesPerElement(with_weights) * num_rows * num_columns;
size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns *
sizeof(DenseCuts::WQSketch::Entry);
// divide by 2 is because the memory quota used in sorting is reused for storing cuts.
return bytes_num_elements / 2 + bytes_cuts;
}
TEST(HistUtil, DeviceSketchMemory) { TEST(HistUtil, DeviceSketchMemory) {
int num_columns = 100; int num_columns = 100;
int num_rows = 1000; int num_rows = 1000;
@ -77,15 +58,15 @@ TEST(HistUtil, DeviceSketchMemory) {
dh::GlobalMemoryLogger().Clear(); dh::GlobalMemoryLogger().Clear();
ConsoleLogger::Configure({{"verbosity", "3"}}); ConsoleLogger::Configure({{"verbosity", "3"}});
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
ConsoleLogger::Configure({{"verbosity", "0"}});
size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, false); size_t bytes_required = detail::RequiredMemory(
size_t bytes_constant = 1000; num_rows, num_columns, num_rows * num_columns, num_bins, false);
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant); EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05);
EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required); EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 0.95);
ConsoleLogger::Configure({{"verbosity", "0"}});
} }
TEST(HistUtil, DeviceSketchMemoryWeights) { TEST(HistUtil, DeviceSketchWeightsMemory) {
int num_columns = 100; int num_columns = 100;
int num_rows = 1000; int num_rows = 1000;
int num_bins = 256; int num_bins = 256;
@ -98,7 +79,8 @@ TEST(HistUtil, DeviceSketchMemoryWeights) {
auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); auto device_cuts = DeviceSketch(0, dmat.get(), num_bins);
ConsoleLogger::Configure({{"verbosity", "0"}}); ConsoleLogger::Configure({{"verbosity", "0"}});
size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, true); size_t bytes_required = detail::RequiredMemory(
num_rows, num_columns, num_rows * num_columns, num_bins, true);
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05); EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05);
EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required); EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required);
} }
@ -118,7 +100,7 @@ TEST(HistUtil, DeviceSketchDeterminism) {
} }
} }
TEST(HistUtil, DeviceSketchCategorical) { TEST(HistUtil, DeviceSketchCategorical) {
int categorical_sizes[] = {2, 6, 8, 12}; int categorical_sizes[] = {2, 6, 8, 12};
int num_bins = 256; int num_bins = 256;
int sizes[] = {25, 100, 1000}; int sizes[] = {25, 100, 1000};
@ -231,11 +213,10 @@ template <typename Adapter>
void ValidateBatchedCuts(Adapter adapter, int num_bins, int num_columns, int num_rows, void ValidateBatchedCuts(Adapter adapter, int num_bins, int num_columns, int num_rows,
DMatrix* dmat) { DMatrix* dmat) {
common::HistogramCuts batched_cuts; common::HistogramCuts batched_cuts;
SketchContainer sketch_container(num_bins, num_columns, num_rows); SketchContainer sketch_container(num_bins, num_columns, num_rows, 0);
AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits<float>::quiet_NaN(), AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits<float>::quiet_NaN(),
0, &sketch_container); &sketch_container);
common::DenseCuts dense_cuts(&batched_cuts); sketch_container.MakeCuts(&batched_cuts);
dense_cuts.Init(&sketch_container.sketches_, num_bins, num_rows);
ValidateCuts(batched_cuts, dmat, num_bins); ValidateCuts(batched_cuts, dmat, num_bins);
} }
@ -275,12 +256,13 @@ TEST(HistUtil, AdapterDeviceSketchMemory) {
std::numeric_limits<float>::quiet_NaN()); std::numeric_limits<float>::quiet_NaN());
ConsoleLogger::Configure({{"verbosity", "0"}}); ConsoleLogger::Configure({{"verbosity", "0"}});
size_t bytes_constant = 1000; size_t bytes_constant = 1000;
size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, false); size_t bytes_required = detail::RequiredMemory(
num_rows, num_columns, num_rows * num_columns, num_bins, false);
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant); EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant);
EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required); EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 0.95);
} }
TEST(HistUtil, AdapterSketchBatchMemory) { TEST(HistUtil, AdapterSketchSlidingWindowMemory) {
int num_columns = 100; int num_columns = 100;
int num_rows = 1000; int num_rows = 1000;
int num_bins = 256; int num_bins = 256;
@ -291,17 +273,19 @@ TEST(HistUtil, AdapterSketchBatchMemory) {
dh::GlobalMemoryLogger().Clear(); dh::GlobalMemoryLogger().Clear();
ConsoleLogger::Configure({{"verbosity", "3"}}); ConsoleLogger::Configure({{"verbosity", "3"}});
common::HistogramCuts batched_cuts; common::HistogramCuts batched_cuts;
SketchContainer sketch_container(num_bins, num_columns, num_rows); SketchContainer sketch_container(num_bins, num_columns, num_rows, 0);
AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits<float>::quiet_NaN(), AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits<float>::quiet_NaN(),
0, &sketch_container); &sketch_container);
HistogramCuts cuts;
sketch_container.MakeCuts(&cuts);
size_t bytes_required = detail::RequiredMemory(
num_rows, num_columns, num_rows * num_columns, num_bins, false);
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05);
EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 0.95);
ConsoleLogger::Configure({{"verbosity", "0"}}); ConsoleLogger::Configure({{"verbosity", "0"}});
size_t bytes_constant = 1000;
size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, false);
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant);
EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required);
} }
TEST(HistUtil, AdapterSketchBatchWeightedMemory) { TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) {
int num_columns = 100; int num_columns = 100;
int num_rows = 1000; int num_rows = 1000;
int num_bins = 256; int num_bins = 256;
@ -316,12 +300,15 @@ TEST(HistUtil, AdapterSketchBatchWeightedMemory) {
dh::GlobalMemoryLogger().Clear(); dh::GlobalMemoryLogger().Clear();
ConsoleLogger::Configure({{"verbosity", "3"}}); ConsoleLogger::Configure({{"verbosity", "3"}});
common::HistogramCuts batched_cuts; common::HistogramCuts batched_cuts;
SketchContainer sketch_container(num_bins, num_columns, num_rows); SketchContainer sketch_container(num_bins, num_columns, num_rows, 0);
AdapterDeviceSketchWeighted(adapter.Value(), num_bins, info, AdapterDeviceSketchWeighted(adapter.Value(), num_bins, info,
std::numeric_limits<float>::quiet_NaN(), 0, std::numeric_limits<float>::quiet_NaN(),
&sketch_container); &sketch_container);
HistogramCuts cuts;
sketch_container.MakeCuts(&cuts);
ConsoleLogger::Configure({{"verbosity", "0"}}); ConsoleLogger::Configure({{"verbosity", "0"}});
size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, true); size_t bytes_required = detail::RequiredMemory(
num_rows, num_columns, num_rows * num_columns, num_bins, true);
EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05); EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05);
EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required); EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required);
} }
@ -462,13 +449,11 @@ void TestAdapterSketchFromWeights(bool with_group) {
data::CupyAdapter adapter(m); data::CupyAdapter adapter(m);
auto const& batch = adapter.Value(); auto const& batch = adapter.Value();
SketchContainer sketch_container(kBins, kCols, kRows); SketchContainer sketch_container(kBins, kCols, kRows, 0);
AdapterDeviceSketchWeighted(adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(), AdapterDeviceSketchWeighted(adapter.Value(), kBins, info, std::numeric_limits<float>::quiet_NaN(),
0,
&sketch_container); &sketch_container);
common::HistogramCuts cuts; common::HistogramCuts cuts;
common::DenseCuts dense_cuts(&cuts); sketch_container.MakeCuts(&cuts);
dense_cuts.Init(&sketch_container.sketches_, kBins, kRows);
auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols); auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols);
if (with_group) { if (with_group) {


@ -117,7 +117,7 @@ inline void TestBinDistribution(const HistogramCuts &cuts, int column_idx,
// First and last bin can have smaller // First and last bin can have smaller
for (auto& kv : bin_weights) { for (auto& kv : bin_weights) {
EXPECT_LE(std::abs(bin_weights[kv.first] - expected_bin_weight), ASSERT_LE(std::abs(bin_weights[kv.first] - expected_bin_weight),
allowable_error); allowable_error);
} }
} }
@ -189,7 +189,7 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat,
// Collect data into columns // Collect data into columns
std::vector<std::vector<float>> columns(dmat->Info().num_col_); std::vector<std::vector<float>> columns(dmat->Info().num_col_);
for (auto& batch : dmat->GetBatches<SparsePage>()) { for (auto& batch : dmat->GetBatches<SparsePage>()) {
CHECK_GT(batch.Size(), 0); ASSERT_GT(batch.Size(), 0);
for (auto i = 0ull; i < batch.Size(); i++) { for (auto i = 0ull; i < batch.Size(); i++) {
for (auto e : batch[i]) { for (auto e : batch[i]) {
columns[e.index].push_back(e.fvalue); columns[e.index].push_back(e.fvalue);

0
tests/cpp/common/test_partition_builder.cc Executable file → Normal file


@ -0,0 +1,480 @@
#include <gtest/gtest.h>
#include "../helpers.h"
#include "../../../src/common/hist_util.cuh"
#include "../../../src/common/quantile.cuh"
namespace xgboost {
namespace common {
TEST(GPUQuantile, Basic) {
constexpr size_t kRows = 1000, kCols = 100, kBins = 256;
SketchContainer sketch(kBins, kCols, kRows, 0);
dh::caching_device_vector<SketchEntry> entries;
dh::device_vector<bst_row_t> cuts_ptr(kCols+1);
thrust::fill(cuts_ptr.begin(), cuts_ptr.end(), 0);
// Push empty
sketch.Push(dh::ToSpan(cuts_ptr), &entries);
ASSERT_EQ(sketch.Data().size(), 0);
}
template <typename Fn> void RunWithSeedsAndBins(size_t rows, Fn fn) {
std::vector<int32_t> seeds(4);
SimpleLCG lcg;
SimpleRealUniformDistribution<float> dist(3, 1000);
std::generate(seeds.begin(), seeds.end(), [&](){ return dist(&lcg); });
std::vector<size_t> bins(8);
for (size_t i = 0; i < bins.size() - 1; ++i) {
bins[i] = i * 35 + 2;
}
bins.back() = rows + 80; // provide a bin number greater than rows.
std::vector<MetaInfo> infos(2);
auto& h_weights = infos.front().weights_.HostVector();
h_weights.resize(rows);
std::generate(h_weights.begin(), h_weights.end(), [&]() { return dist(&lcg); });
for (auto seed : seeds) {
for (auto n_bin : bins) {
for (auto const& info : infos) {
fn(seed, n_bin, info);
}
}
}
}
void TestSketchUnique(float sparsity) {
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [kRows, kCols, sparsity](int32_t seed, size_t n_bins, MetaInfo const& info) {
SketchContainer sketch(n_bins, kCols, kRows, 0);
HostDeviceVector<float> storage;
std::string interface_str = RandomDataGenerator{kRows, kCols, sparsity}
.Seed(seed)
.Device(0)
.GenerateArrayInterface(&storage);
data::CupyAdapter adapter(interface_str);
AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(), &sketch);
auto n_cuts = detail::RequiredSampleCutsPerColumn(n_bins, kRows);
dh::caching_device_vector<size_t> column_sizes_scan;
HostDeviceVector<size_t> cut_sizes_scan;
auto batch = adapter.Value();
data::IsValidFunctor is_valid(std::numeric_limits<float>::quiet_NaN());
auto batch_iter = dh::MakeTransformIterator<data::COOTuple>(
thrust::make_counting_iterator(0llu),
[=] __device__(size_t idx) { return batch.GetElement(idx); });
auto end = kCols * kRows;
detail::GetColumnSizesScan(0, kCols, n_cuts, batch_iter, is_valid, 0, end,
&cut_sizes_scan, &column_sizes_scan);
auto const& cut_sizes = cut_sizes_scan.HostVector();
if (sparsity == 0) {
ASSERT_EQ(sketch.Data().size(), n_cuts * kCols);
} else {
ASSERT_EQ(sketch.Data().size(), cut_sizes.back());
}
sketch.Unique();
ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch.Data().data(),
sketch.Data().data() + sketch.Data().size(),
detail::SketchUnique{}));
});
}
TEST(GPUQuantile, Unique) {
TestSketchUnique(0);
TestSketchUnique(0.5);
}
// if with_error is true, the test tolerates floating point error
void TestQuantileElemRank(int32_t device, Span<SketchEntry const> in,
Span<bst_row_t const> d_columns_ptr, bool with_error = false) {
dh::LaunchN(device, in.size(), [=]XGBOOST_DEVICE(size_t idx) {
auto column_id = dh::SegmentId(d_columns_ptr, idx);
auto in_column = in.subspan(d_columns_ptr[column_id],
d_columns_ptr[column_id + 1] -
d_columns_ptr[column_id]);
auto constexpr kEps = 1e-6f;
idx -= d_columns_ptr[column_id];
float prev_rmin = idx == 0 ? 0.0f : in_column[idx-1].rmin;
float prev_rmax = idx == 0 ? 0.0f : in_column[idx-1].rmax;
float rmin_next = in_column[idx].RMinNext();
if (with_error) {
SPAN_CHECK(in_column[idx].rmin + in_column[idx].rmin * kEps >= prev_rmin);
SPAN_CHECK(in_column[idx].rmax + in_column[idx].rmin * kEps >= prev_rmax);
SPAN_CHECK(in_column[idx].rmax + in_column[idx].rmin * kEps >= rmin_next);
} else {
SPAN_CHECK(in_column[idx].rmin >= prev_rmin);
SPAN_CHECK(in_column[idx].rmax >= prev_rmax);
SPAN_CHECK(in_column[idx].rmax >= rmin_next);
}
});
// Force sync to terminate current test instead of a later one.
dh::DebugSyncDevice(__FILE__, __LINE__);
}
TEST(GPUQuantile, Prune) {
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
SketchContainer sketch(n_bins, kCols, kRows, 0);
HostDeviceVector<float> storage;
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
.Device(0)
.Seed(seed)
.GenerateArrayInterface(&storage);
data::CupyAdapter adapter(interface_str);
AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(), &sketch);
auto n_cuts = detail::RequiredSampleCutsPerColumn(n_bins, kRows);
ASSERT_EQ(sketch.Data().size(), n_cuts * kCols);
sketch.Prune(n_bins);
if (n_bins <= kRows) {
ASSERT_EQ(sketch.Data().size(), n_bins * kCols);
} else {
// LE because kRows * kCols entries are pushed into the sketch; after removing
// duplicated entries we might not have that many inputs for prune.
ASSERT_LE(sketch.Data().size(), kRows * kCols);
}
// This is not necessarily true for all inputs without calling unique after
// prune.
ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch.Data().data(),
sketch.Data().data() + sketch.Data().size(),
detail::SketchUnique{}));
TestQuantileElemRank(0, sketch.Data(), sketch.ColumnsPtr());
});
}
TEST(GPUQuantile, MergeEmpty) {
constexpr size_t kRows = 1000, kCols = 100;
size_t n_bins = 10;
SketchContainer sketch_0(n_bins, kCols, kRows, 0);
HostDeviceVector<float> storage_0;
std::string interface_str_0 =
RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateArrayInterface(
&storage_0);
data::CupyAdapter adapter_0(interface_str_0);
AdapterDeviceSketch(adapter_0.Value(), n_bins,
std::numeric_limits<float>::quiet_NaN(), &sketch_0);
std::vector<SketchEntry> entries_before(sketch_0.Data().size());
dh::CopyDeviceSpanToVector(&entries_before, sketch_0.Data());
std::vector<bst_row_t> ptrs_before(sketch_0.ColumnsPtr().size());
dh::CopyDeviceSpanToVector(&ptrs_before, sketch_0.ColumnsPtr());
thrust::device_vector<size_t> columns_ptr(kCols + 1);
// Merge an empty sketch
sketch_0.Merge(dh::ToSpan(columns_ptr), Span<SketchEntry>{});
std::vector<SketchEntry> entries_after(sketch_0.Data().size());
dh::CopyDeviceSpanToVector(&entries_after, sketch_0.Data());
std::vector<bst_row_t> ptrs_after(sketch_0.ColumnsPtr().size());
dh::CopyDeviceSpanToVector(&ptrs_after, sketch_0.ColumnsPtr());
CHECK_EQ(entries_before.size(), entries_after.size());
CHECK_EQ(ptrs_before.size(), ptrs_after.size());
for (size_t i = 0; i < entries_before.size(); ++i) {
CHECK_EQ(entries_before[i].value, entries_after[i].value);
CHECK_EQ(entries_before[i].rmin, entries_after[i].rmin);
CHECK_EQ(entries_before[i].rmax, entries_after[i].rmax);
CHECK_EQ(entries_before[i].wmin, entries_after[i].wmin);
}
for (size_t i = 0; i < ptrs_before.size(); ++i) {
CHECK_EQ(ptrs_before[i], ptrs_after[i]);
}
}
TEST(GPUQuantile, MergeBasic) {
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
SketchContainer sketch_0(n_bins, kCols, kRows, 0);
HostDeviceVector<float> storage_0;
std::string interface_str_0 = RandomDataGenerator{kRows, kCols, 0}
.Device(0)
.Seed(seed)
.GenerateArrayInterface(&storage_0);
data::CupyAdapter adapter_0(interface_str_0);
AdapterDeviceSketchWeighted(adapter_0.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(), &sketch_0);
SketchContainer sketch_1(n_bins, kCols, kRows * kRows, 0);
HostDeviceVector<float> storage_1;
std::string interface_str_1 = RandomDataGenerator{kRows, kCols, 0}
.Device(0)
.Seed(seed)
.GenerateArrayInterface(&storage_1);
data::CupyAdapter adapter_1(interface_str_1);
AdapterDeviceSketchWeighted(adapter_1.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(), &sketch_1);
size_t size_before_merge = sketch_0.Data().size();
sketch_0.Merge(sketch_1.ColumnsPtr(), sketch_1.Data());
if (info.weights_.Size() != 0) {
TestQuantileElemRank(0, sketch_0.Data(), sketch_0.ColumnsPtr(), true);
sketch_0.FixError();
TestQuantileElemRank(0, sketch_0.Data(), sketch_0.ColumnsPtr(), false);
} else {
TestQuantileElemRank(0, sketch_0.Data(), sketch_0.ColumnsPtr());
}
auto columns_ptr = sketch_0.ColumnsPtr();
std::vector<bst_row_t> h_columns_ptr(columns_ptr.size());
dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge);
sketch_0.Unique();
ASSERT_TRUE(
thrust::is_sorted(thrust::device, sketch_0.Data().data(),
sketch_0.Data().data() + sketch_0.Data().size(),
detail::SketchUnique{}));
});
}
void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) {
MetaInfo info;
int32_t seed = 0;
SketchContainer sketch_0(n_bins, cols, rows, 0);
HostDeviceVector<float> storage_0;
std::string interface_str_0 = RandomDataGenerator{rows, cols, 0}
.Device(0)
.Seed(seed)
.GenerateArrayInterface(&storage_0);
data::CupyAdapter adapter_0(interface_str_0);
AdapterDeviceSketchWeighted(adapter_0.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(),
&sketch_0);
size_t f_rows = rows * frac;
SketchContainer sketch_1(n_bins, cols, f_rows, 0);
HostDeviceVector<float> storage_1;
std::string interface_str_1 = RandomDataGenerator{f_rows, cols, 0}
.Device(0)
.Seed(seed)
.GenerateArrayInterface(&storage_1);
auto data_1 = storage_1.DeviceSpan();
auto tuple_it = thrust::make_tuple(
thrust::make_counting_iterator<size_t>(0ul), data_1.data());
using Tuple = thrust::tuple<size_t, float>;
auto it = thrust::make_zip_iterator(tuple_it);
thrust::transform(thrust::device, it, it + data_1.size(), data_1.data(),
[=] __device__(Tuple const &tuple) {
auto i = thrust::get<0>(tuple);
if (thrust::get<0>(tuple) % 2 == 0) {
return 0.0f;
} else {
return thrust::get<1>(tuple);
}
});
data::CupyAdapter adapter_1(interface_str_1);
AdapterDeviceSketchWeighted(adapter_1.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(),
&sketch_1);
size_t size_before_merge = sketch_0.Data().size();
sketch_0.Merge(sketch_1.ColumnsPtr(), sketch_1.Data());
TestQuantileElemRank(0, sketch_0.Data(), sketch_0.ColumnsPtr());
auto columns_ptr = sketch_0.ColumnsPtr();
std::vector<bst_row_t> h_columns_ptr(columns_ptr.size());
dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge);
sketch_0.Unique();
ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch_0.Data().data(),
sketch_0.Data().data() + sketch_0.Data().size(),
detail::SketchUnique{}));
}
TEST(GPUQuantile, MergeDuplicated) {
size_t n_bins = 256;
constexpr size_t kRows = 1000, kCols = 100;
for (float frac = 0.5; frac < 2.5; frac += 0.5) {
TestMergeDuplicated(n_bins, kRows, kCols, frac);
}
}
void InitRabitContext(std::string msg) {
auto n_gpus = AllVisibleGPUs();
auto port = std::getenv("DMLC_TRACKER_PORT");
std::string port_str;
if (port) {
port_str = port;
} else {
LOG(WARNING) << msg << " as `DMLC_TRACKER_PORT` is not set up.";
return;
}
std::vector<std::string> envs{
"DMLC_TRACKER_PORT=" + port_str,
"DMLC_TRACKER_URI=127.0.0.1",
"DMLC_NUM_WORKER=" + std::to_string(n_gpus)};
char* c_envs[] {&(envs[0][0]), &(envs[1][0]), &(envs[2][0])};
rabit::Init(3, c_envs);
}
TEST(GPUQuantile, AllReduceBasic) {
// This test is supposed to run by a python test that setups the environment.
std::string msg {"Skipping AllReduce test"};
#if defined(__linux__) && defined(XGBOOST_USE_NCCL)
InitRabitContext(msg);
auto n_gpus = AllVisibleGPUs();
auto world = rabit::GetWorldSize();
if (world != 1) {
ASSERT_EQ(world, n_gpus);
} else {
return;
}
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
// Set up single node version;
SketchContainer sketch_on_single_node(n_bins, kCols, kRows, 0);
size_t intermediate_num_cuts =
std::min(kRows * world, static_cast<size_t>(n_bins * WQSketch::kFactor));
std::vector<SketchContainer> containers;
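// Build one sketch per simulated rank from rank-seeded data, prune each to the shared
// intermediate size, then merge everything into a single container. The distributed
// version below is expected to reproduce this result.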
for (auto rank = 0; rank < world; ++rank) {
HostDeviceVector<float> storage;
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
.Device(0)
.Seed(rank + seed)
.GenerateArrayInterface(&storage);
data::CupyAdapter adapter(interface_str);
containers.emplace_back(n_bins, kCols, kRows, 0);
AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(),
&containers.back());
}
for (auto& sketch : containers) {
sketch.Prune(intermediate_num_cuts);
sketch_on_single_node.Merge(sketch.ColumnsPtr(), sketch.Data());
sketch_on_single_node.FixError();
}
sketch_on_single_node.Unique();
TestQuantileElemRank(0, sketch_on_single_node.Data(),
sketch_on_single_node.ColumnsPtr());
// Set up distributed version. We rely on using rank as seed to generate
// the exact same copy of data.
auto rank = rabit::GetRank();
SketchContainer sketch_distributed(n_bins, kCols, kRows, 0);
HostDeviceVector<float> storage;
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
.Device(0)
.Seed(rank + seed)
.GenerateArrayInterface(&storage);
data::CupyAdapter adapter(interface_str);
AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(),
&sketch_distributed);
sketch_distributed.AllReduce();
sketch_distributed.Unique();
ASSERT_EQ(sketch_distributed.ColumnsPtr().size(),
sketch_on_single_node.ColumnsPtr().size());
ASSERT_EQ(sketch_distributed.Data().size(),
sketch_on_single_node.Data().size());
TestQuantileElemRank(0, sketch_distributed.Data(),
sketch_distributed.ColumnsPtr());
std::vector<SketchEntry> single_node_data(
sketch_on_single_node.Data().size());
dh::CopyDeviceSpanToVector(&single_node_data, sketch_on_single_node.Data());
std::vector<SketchEntry> distributed_data(sketch_distributed.Data().size());
dh::CopyDeviceSpanToVector(&distributed_data, sketch_distributed.Data());
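// The tolerance scales with the number of workers, as each additional merge can
// accumulate a small amount of floating point error.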
float Eps = 2e-4 * world;
for (size_t i = 0; i < single_node_data.size(); ++i) {
ASSERT_NEAR(single_node_data[i].value, distributed_data[i].value, Eps);
ASSERT_NEAR(single_node_data[i].rmax, distributed_data[i].rmax, Eps);
ASSERT_NEAR(single_node_data[i].rmin, distributed_data[i].rmin, Eps);
ASSERT_NEAR(single_node_data[i].wmin, distributed_data[i].wmin, Eps);
}
});
rabit::Finalize();
#else
LOG(WARNING) << msg;
return;
#endif  // defined(__linux__) && defined(XGBOOST_USE_NCCL)
}
TEST(GPUQuantile, SameOnAllWorkers) {
std::string msg {"Skipping SameOnAllWorkers test"};
#if defined(__linux__) && defined(XGBOOST_USE_NCCL)
InitRabitContext(msg);
auto world = rabit::GetWorldSize();
auto n_gpus = AllVisibleGPUs();
if (world != 1) {
ASSERT_EQ(world, n_gpus);
} else {
return;
}
constexpr size_t kRows = 1000, kCols = 100;
RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins,
MetaInfo const &info) {
auto rank = rabit::GetRank();
SketchContainer sketch_distributed(n_bins, kCols, kRows, 0);
HostDeviceVector<float> storage;
std::string interface_str = RandomDataGenerator{kRows, kCols, 0}
.Device(0)
.Seed(rank + seed)
.GenerateArrayInterface(&storage);
data::CupyAdapter adapter(interface_str);
AdapterDeviceSketchWeighted(adapter.Value(), n_bins, info,
std::numeric_limits<float>::quiet_NaN(),
&sketch_distributed);
sketch_distributed.AllReduce();
sketch_distributed.Unique();
TestQuantileElemRank(0, sketch_distributed.Data(), sketch_distributed.ColumnsPtr());
// Test for all workers having the same sketch.
size_t n_data = sketch_distributed.Data().size();
rabit::Allreduce<rabit::op::Max>(&n_data, 1);
ASSERT_EQ(n_data, sketch_distributed.Data().size());
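// Reinterpret the sketch entries as raw floats and gather every worker's copy into its
// own slot of a zero-initialized buffer via an all-reduce sum, then compare each slot
// against rank 0.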
size_t size_as_float =
sketch_distributed.Data().size_bytes() / sizeof(float);
auto local_data = Span<float const>{
reinterpret_cast<float const *>(sketch_distributed.Data().data()),
size_as_float};
dh::caching_device_vector<float> all_workers(size_as_float * world);
thrust::fill(all_workers.begin(), all_workers.end(), 0);
thrust::copy(thrust::device, local_data.data(),
local_data.data() + local_data.size(),
all_workers.begin() + local_data.size() * rank);
dh::AllReducer reducer;
reducer.Init(0);
reducer.AllReduceSum(all_workers.data().get(), all_workers.data().get(),
all_workers.size());
reducer.Synchronize();
auto base_line = dh::ToSpan(all_workers).subspan(0, size_as_float);
std::vector<float> h_base_line(base_line.size());
dh::CopyDeviceSpanToVector(&h_base_line, base_line);
size_t offset = 0;
for (size_t i = 0; i < world; ++i) {
auto comp = dh::ToSpan(all_workers).subspan(offset, size_as_float);
std::vector<float> h_comp(comp.size());
dh::CopyDeviceSpanToVector(&h_comp, comp);
ASSERT_EQ(comp.size(), base_line.size());
for (size_t j = 0; j < h_comp.size(); ++j) {
ASSERT_NEAR(h_base_line[j], h_comp[j], kRtEps);
}
offset += size_as_float;
}
});
#else
LOG(WARNING) << msg;
return;
#endif  // defined(__linux__) && defined(XGBOOST_USE_NCCL)
}
} // namespace common
} // namespace xgboost


@@ -422,11 +422,11 @@ TEST(Span, Subspan) {
ASSERT_EQ(s4.size(), s1.size() - 2);
EXPECT_DEATH(s1.subspan(-1, 0), "\\[xgboost\\] Condition .* failed.\n");
- EXPECT_DEATH(s1.subspan(16, 0), "\\[xgboost\\] Condition .* failed.\n");
+ EXPECT_DEATH(s1.subspan(17, 0), "\\[xgboost\\] Condition .* failed.\n");
auto constexpr kOne = static_cast<Span<int, 4>::index_type>(-1);
EXPECT_DEATH(s1.subspan<kOne>(), "\\[xgboost\\] Condition .* failed.\n");
- EXPECT_DEATH(s1.subspan<16>(), "\\[xgboost\\] Condition .* failed.\n");
+ EXPECT_DEATH(s1.subspan<17>(), "\\[xgboost\\] Condition .* failed.\n");
}
TEST(Span, Compare) {

tests/cpp/common/test_threading_utils.cc Executable file → Normal file


@@ -168,7 +168,9 @@ class SimpleRealUniformDistribution {
ResultT operator()(GeneratorT* rng) const {
ResultT tmp = GenerateCanonical<std::numeric_limits<ResultT>::digits,
                                GeneratorT>(rng);
- return (tmp * (upper_ - lower_)) + lower_;
+ auto ret = (tmp * (upper_ - lower_)) + lower_;
+ // Correct floating point error.
+ return std::max(ret, lower_);
}
};


@@ -2,3 +2,4 @@
markers =
    mgpu: Mark a test that requires multiple GPUs to run.
    ci: Mark a test that runs only on CI.
+   gtest: Mark a test that requires C++ Google Test executable.


@@ -1,8 +1,12 @@
import sys
+ import os
import pytest
import numpy as np
import unittest
import xgboost
+ import subprocess
+ from hypothesis import given, strategies, settings, note
+ from test_gpu_updaters import parameter_strategy

if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
@@ -12,11 +16,13 @@ from test_with_dask import run_empty_dmatrix  # noqa
from test_with_dask import generate_array  # noqa
import testing as tm  # noqa

try:
import dask.dataframe as dd
from xgboost import dask as dxgb
from dask_cuda import LocalCUDACluster
from dask.distributed import Client
+ from dask import array as da
import cudf
except ImportError:
pass
@@ -42,7 +48,8 @@ class TestDistributedGPU(unittest.TestCase):
y = y.map_partitions(cudf.from_pandas)

dtrain = dxgb.DaskDMatrix(client, X, y)
- out = dxgb.train(client, {'tree_method': 'gpu_hist'},
+ out = dxgb.train(client, {'tree_method': 'gpu_hist',
+ 'debug_synchronize': True},
dtrain=dtrain,
evals=[(dtrain, 'X')],
num_boost_round=4)
@@ -61,7 +68,8 @@ class TestDistributedGPU(unittest.TestCase):
xgboost.DMatrix(X.compute()))
cp.testing.assert_allclose(single_node, predictions)
- np.testing.assert_allclose(single_node, series_predictions.to_array())
+ np.testing.assert_allclose(single_node,
+ series_predictions.to_array())

predt = dxgb.predict(client, out, X)
assert isinstance(predt, dd.Series)
@@ -77,6 +85,41 @@ class TestDistributedGPU(unittest.TestCase):
cp.testing.assert_allclose(
predt.values.compute(), single_node)

@given(parameter_strategy, strategies.integers(1, 20),
tm.dataset_strategy)
@settings(deadline=None)
@pytest.mark.mgpu
def test_gpu_hist(self, params, num_rounds, dataset):
with LocalCUDACluster(n_workers=2) as cluster:
with Client(cluster) as client:
params['tree_method'] = 'gpu_hist'
params = dataset.set_params(params)
# multi class doesn't handle empty dataset well (empty
# means at least 1 worker has data).
if params['objective'] == "multi:softmax":
return
# It doesn't make sense to distribute a completely
# empty dataset.
if dataset.X.shape[0] == 0:
return
chunk = 128
X = da.from_array(dataset.X,
chunks=(chunk, dataset.X.shape[1]))
y = da.from_array(dataset.y, chunks=(chunk, ))
if dataset.w is not None:
w = da.from_array(dataset.w, chunks=(chunk, ))
else:
w = None
m = dxgb.DaskDMatrix(
client, data=X, label=y, weight=w)
history = dxgb.train(client, params=params, dtrain=m,
num_boost_round=num_rounds,
evals=[(m, 'train')])['history']
note(history)
assert tm.non_increasing(history['train'][dataset.metric])
@pytest.mark.skipif(**tm.no_cupy())
@pytest.mark.mgpu
def test_dask_array(self):
@@ -89,7 +132,8 @@ class TestDistributedGPU(unittest.TestCase):
X = X.map_blocks(cp.asarray)
y = y.map_blocks(cp.asarray)
dtrain = dxgb.DaskDMatrix(client, X, y)
- out = dxgb.train(client, {'tree_method': 'gpu_hist'},
+ out = dxgb.train(client, {'tree_method': 'gpu_hist',
+ 'debug_synchronize': True},
dtrain=dtrain,
evals=[(dtrain, 'X')],
num_boost_round=2)
@@ -107,12 +151,62 @@ class TestDistributedGPU(unittest.TestCase):
single_node,
inplace_predictions)

@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.skipif(**tm.no_dask_cuda())
@pytest.mark.mgpu
def test_empty_dmatrix(self):
with LocalCUDACluster() as cluster:
with Client(cluster) as client:
- parameters = {'tree_method': 'gpu_hist'}
+ parameters = {'tree_method': 'gpu_hist',
+ 'debug_synchronize': True}
run_empty_dmatrix(client, parameters)
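
# Locate the compiled `testxgboost` gtest binary and launch the requested
# GPUQuantile test on every dask-CUDA worker, forwarding the rabit tracker
# port through the environment so the C++ tests can reach the tracker.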
def run_quantile(self, name):
if sys.platform.startswith("win"):
pytest.skip("Skipping dask tests on Windows")
exe = None
for possible_path in {'./testxgboost', './build/testxgboost',
'../build/testxgboost', '../gpu-build/testxgboost'}:
if os.path.exists(possible_path):
exe = possible_path
assert exe, 'No testxgboost executable found.'
test = "--gtest_filter=GPUQuantile." + name
def runit(worker_addr, rabit_args):
port = None
# setup environment for running the c++ part.
for arg in rabit_args:
if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'):
port = arg.decode('utf-8')
port = port.split('=')
env = os.environ.copy()
env[port[0]] = port[1]
return subprocess.run([exe, test], env=env, stdout=subprocess.PIPE)
with LocalCUDACluster() as cluster:
with Client(cluster) as client:
workers = list(dxgb._get_client_workers(client).keys())
rabit_args = dxgb._get_rabit_args(workers, client)
futures = client.map(runit,
workers,
pure=False,
workers=workers,
rabit_args=rabit_args)
results = client.gather(futures)
for ret in results:
msg = ret.stdout.decode('utf-8')
assert msg.find('1 test from GPUQuantile') != -1, msg
assert ret.returncode == 0, msg
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.mgpu
@pytest.mark.gtest
def test_quantile_basic(self):
self.run_quantile('AllReduceBasic')
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.mgpu
@pytest.mark.gtest
def test_quantile_same_on_all_workers(self):
self.run_quantile('SameOnAllWorkers')


@@ -1,10 +1,12 @@
# coding: utf-8
+ import os
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
from xgboost.compat import DASK_INSTALLED
from hypothesis import strategies
from hypothesis.extra.numpy import arrays
from joblib import Memory
from sklearn import datasets
+ import tempfile
import xgboost as xgb
import numpy as np
@@ -123,10 +125,15 @@ class TestDataset:
return xgb.DeviceQuantileDMatrix(X, y, w)

def get_external_dmat(self):
- np.savetxt('tmptmp_1234.csv', np.hstack((self.y.reshape(len(self.y), 1), self.X)),
-            delimiter=',')
- return xgb.DMatrix('tmptmp_1234.csv?format=csv&label_column=0#tmptmp_',
-                    weight=self.w)
+ with tempfile.TemporaryDirectory() as tmpdir:
+     path = os.path.join(tmpdir, 'tmptmp_1234.csv')
+     np.savetxt(path,
+                np.hstack((self.y.reshape(len(self.y), 1), self.X)),
+                delimiter=',')
+     uri = path + '?format=csv&label_column=0#tmptmp_'
+     # The uri looks like:
+     # 'tmptmp_1234.csv?format=csv&label_column=0#tmptmp_'
+     return xgb.DMatrix(uri, weight=self.w)

def __repr__(self):
return self.name
@@ -181,6 +188,7 @@ def _dataset_and_weight(draw):
data.w = draw(arrays(np.float64, (len(data.y)), elements=strategies.floats(0.1, 2.0)))
return data

# A strategy for drawing from a set of example datasets
# May add random weights to the dataset
dataset_strategy = _dataset_and_weight()