Implement GK sketching on GPU. (#5846)

* Implement GK sketching on GPU.
* Strong tests on quantile building.
* Handle sparse dataset by binary searching the column index.
* Hypothesis test on dask.
This commit is contained in:
Jiaming Yuan
2020-07-07 12:16:21 +08:00
committed by GitHub
parent ac3f0e78dc
commit 048d969be4
25 changed files with 2045 additions and 405 deletions

View File

@@ -5,10 +5,16 @@
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/device_malloc_allocator.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#include <thrust/execution_policy.h>
#include <thrust/transform_scan.h>
#include <thrust/logical.h>
#include <thrust/gather.h>
#include <thrust/unique.h>
#include <thrust/binary_search.h>
#include <rabit/rabit.h>
@@ -53,6 +59,36 @@ __device__ __forceinline__ double atomicAdd(double* address, double val) { // N
}
#endif
namespace dh {
namespace detail {
template <size_t size>
struct AtomicDispatcher;
template <>
struct AtomicDispatcher<sizeof(uint32_t)> {
using Type = unsigned int; // NOLINT
static_assert(sizeof(Type) == sizeof(uint32_t), "Unsigned should be of size 32 bits.");
};
template <>
struct AtomicDispatcher<sizeof(uint64_t)> {
using Type = unsigned long long; // NOLINT
static_assert(sizeof(Type) == sizeof(uint64_t), "Unsigned long long should be of size 64 bits.");
};
} // namespace detail
} // namespace dh
// atomicAdd is not defined for size_t.
template <typename T = size_t,
std::enable_if_t<std::is_same<size_t, T>::value &&
!std::is_same<size_t, unsigned long long>::value> * = // NOLINT
nullptr>
T __device__ __forceinline__ atomicAdd(T *addr, T v) { // NOLINT
using Type = typename dh::detail::AtomicDispatcher<sizeof(T)>::Type;
Type ret = ::atomicAdd(reinterpret_cast<Type *>(addr), static_cast<Type>(v));
return static_cast<T>(ret);
}
namespace dh {
#define HOST_DEV_INLINE XGBOOST_DEVICE __forceinline__
@@ -291,10 +327,12 @@ public:
safe_cuda(cudaGetDevice(&current_device));
stats_.RegisterDeallocation(ptr, n, current_device);
}
size_t PeakMemory()
{
size_t PeakMemory() const {
return stats_.peak_allocated_bytes;
}
size_t CurrentlyAllocatedBytes() const {
return stats_.currently_allocated_bytes;
}
void Clear()
{
stats_ = DeviceStats();
@@ -529,7 +567,6 @@ class AllReducer {
bool initialised_ {false};
size_t allreduce_bytes_ {0}; // Keep statistics of the number of bytes communicated
size_t allreduce_calls_ {0}; // Keep statistics of the number of reduce calls
std::vector<size_t> host_data_; // Used for all reduce on host
#ifdef XGBOOST_USE_NCCL
ncclComm_t comm_;
cudaStream_t stream_;
@@ -569,6 +606,27 @@ class AllReducer {
#endif
}
/**
* \brief Allgather implemented as grouped calls to Broadcast. This way we can accept
* different size of data on different workers.
* \param length_bytes Size of input data in bytes.
* \param segments Size of data on each worker.
* \param recvbuf Buffer storing the result of data from all workers.
*/
void AllGather(void const* data, size_t length_bytes,
std::vector<size_t>* segments, dh::caching_device_vector<char>* recvbuf);
void AllGather(uint32_t const* data, size_t length,
dh::caching_device_vector<uint32_t>* recvbuf) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
size_t world = rabit::GetWorldSize();
recvbuf->resize(length * world);
safe_nccl(ncclAllGather(data, recvbuf->data().get(), length, ncclUint32,
comm_, stream_));
#endif // XGBOOST_USE_NCCL
}
/**
* \brief Allreduce. Use in exactly the same way as NCCL but without needing
* streams or comms.
@@ -607,6 +665,40 @@ class AllReducer {
#endif
}
void AllReduceSum(const uint32_t *sendbuff, uint32_t *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint32, ncclSum, comm_, stream_));
#endif
}
void AllReduceSum(const uint64_t *sendbuff, uint64_t *recvbuff, int count) {
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_));
#endif
}
// Specialization for size_t, which is implementation defined so it might or might not
// be one of uint64_t/uint32_t/unsigned long long/unsigned long.
template <typename T = size_t,
std::enable_if_t<std::is_same<size_t, T>::value &&
!std::is_same<size_t, unsigned long long>::value> // NOLINT
* = nullptr>
void AllReduceSum(const T *sendbuff, T *recvbuff, int count) { // NOLINT
#ifdef XGBOOST_USE_NCCL
CHECK(initialised_);
dh::safe_cuda(cudaSetDevice(device_ordinal_));
static_assert(sizeof(unsigned long long) == sizeof(uint64_t), ""); // NOLINT
dh::safe_nccl(ncclAllReduce(sendbuff, recvbuff, count, ncclUint64, ncclSum, comm_, stream_));
#endif
}
/**
* \fn void Synchronize()
*
@@ -886,9 +978,86 @@ DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
// Thrust version of this function causes error on Windows
template <typename ReturnT, typename IterT, typename FuncT>
thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
XGBOOST_DEVICE thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIterator(
IterT iter, FuncT func) {
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
}
template <typename It>
size_t XGBOOST_DEVICE SegmentId(It first, It last, size_t idx) {
size_t segment_id = thrust::upper_bound(thrust::seq, first, last, idx) -
1 - first;
return segment_id;
}
template <typename T>
size_t XGBOOST_DEVICE SegmentId(xgboost::common::Span<T> segments_ptr, size_t idx) {
return SegmentId(segments_ptr.cbegin(), segments_ptr.cend(), idx);
}
namespace detail {
template <typename Key, typename KeyOutIt>
struct SegmentedUniqueReduceOp {
KeyOutIt key_out;
__device__ Key const& operator()(Key const& key) const {
auto constexpr kOne = static_cast<std::remove_reference_t<decltype(*(key_out + key.first))>>(1);
atomicAdd(&(*(key_out + key.first)), kOne);
return key;
}
};
} // namespace detail
/* \brief Segmented unique function. Keys are pointers to segments with key_segments_last -
* key_segments_first = n_segments + 1.
*
* \pre Input segment and output segment must not overlap.
*
* \param key_segments_first Beginning iterator of segments.
* \param key_segments_last End iterator of segments.
* \param val_first Beginning iterator of values.
* \param val_last End iterator of values.
* \param key_segments_out Output iterator of segments.
* \param val_out Output iterator of values.
*
* \return Number of unique values in total.
*/
template <typename KeyInIt, typename KeyOutIt, typename ValInIt,
typename ValOutIt, typename Comp>
size_t
SegmentedUnique(KeyInIt key_segments_first, KeyInIt key_segments_last, ValInIt val_first,
ValInIt val_last, KeyOutIt key_segments_out, ValOutIt val_out,
Comp comp) {
using Key = thrust::pair<size_t, typename thrust::iterator_traits<ValInIt>::value_type>;
dh::XGBCachingDeviceAllocator<char> alloc;
auto unique_key_it = dh::MakeTransformIterator<Key>(
thrust::make_counting_iterator(static_cast<size_t>(0)),
[=] __device__(size_t i) {
size_t seg = dh::SegmentId(key_segments_first, key_segments_last, i);
return thrust::make_pair(seg, *(val_first + i));
});
size_t segments_len = key_segments_last - key_segments_first;
thrust::fill(thrust::device, key_segments_out, key_segments_out + segments_len, 0);
size_t n_inputs = std::distance(val_first, val_last);
// Reduce the number of uniques elements per segment, avoid creating an intermediate
// array for `reduce_by_key`. It's limited by the types that atomicAdd supports. For
// example, size_t is not supported as of CUDA 10.2.
auto reduce_it = thrust::make_transform_output_iterator(
thrust::make_discard_iterator(),
detail::SegmentedUniqueReduceOp<Key, KeyOutIt>{key_segments_out});
auto uniques_ret = thrust::unique_by_key_copy(
thrust::cuda::par(alloc), unique_key_it, unique_key_it + n_inputs,
val_first, reduce_it, val_out,
[=] __device__(Key const &l, Key const &r) {
if (l.first == r.first) {
// In the same segment.
return comp(l.second, r.second);
}
return false;
});
auto n_uniques = uniques_ret.second - val_out;
CHECK_LE(n_uniques, n_inputs);
thrust::exclusive_scan(thrust::cuda::par(alloc), key_segments_out,
key_segments_out + segments_len, key_segments_out, 0);
return n_uniques;
}
} // namespace dh