Make binary bin search reusable. (#6058)
* Move binary search row to hist util. * Remove dead code.
This commit is contained in:
parent
9c14e430af
commit
80c8547147
@ -84,6 +84,14 @@
|
|||||||
#define XGBOOST_DEVICE
|
#define XGBOOST_DEVICE
|
||||||
#endif // defined (__CUDA__) || defined(__NVCC__)
|
#endif // defined (__CUDA__) || defined(__NVCC__)
|
||||||
|
|
||||||
|
// Inlining qualifiers shared by host-only and CUDA builds.
//  - XGBOOST_HOST_DEV_INLINE: XGBOOST_DEVICE plus __forceinline__.
//  - XGBOOST_DEV_INLINE: __device__ plus __forceinline__.
#if defined(__CUDA__) || defined(__CUDACC__)
|
||||||
|
#define XGBOOST_HOST_DEV_INLINE XGBOOST_DEVICE __forceinline__
|
||||||
|
#define XGBOOST_DEV_INLINE __device__ __forceinline__
|
||||||
|
// Without a CUDA compiler both macros expand to nothing, so a function
// declared with them carries no extra qualifier at all.
// NOTE(review): that also means no `inline` on the host side; non-template
// functions defined in headers with these macros presumably need their own
// `inline` -- confirm at the use sites.
#else
|
||||||
|
#define XGBOOST_HOST_DEV_INLINE
|
||||||
|
#define XGBOOST_DEV_INLINE
|
||||||
|
#endif // defined(__CUDA__) || defined(__CUDACC__)
|
||||||
|
|
||||||
// These checks are for the Makefile.
|
// These checks are for the Makefile.
|
||||||
#if !defined(XGBOOST_MM_PREFETCH_PRESENT) && !defined(XGBOOST_BUILTIN_PREFETCH_PRESENT)
|
#if !defined(XGBOOST_MM_PREFETCH_PRESENT) && !defined(XGBOOST_BUILTIN_PREFETCH_PRESENT)
|
||||||
/* default logic for software pre-fetching */
|
/* default logic for software pre-fetching */
|
||||||
|
|||||||
@ -96,9 +96,6 @@ T __device__ __forceinline__ atomicAdd(T *addr, T v) { // NOLINT
|
|||||||
|
|
||||||
namespace dh {
|
namespace dh {
|
||||||
|
|
||||||
#define HOST_DEV_INLINE XGBOOST_DEVICE __forceinline__
|
|
||||||
#define DEV_INLINE __device__ __forceinline__
|
|
||||||
|
|
||||||
#ifdef XGBOOST_USE_NCCL
|
#ifdef XGBOOST_USE_NCCL
|
||||||
#define safe_nccl(ans) ThrowOnNcclError((ans), __FILE__, __LINE__)
|
#define safe_nccl(ans) ThrowOnNcclError((ans), __FILE__, __LINE__)
|
||||||
|
|
||||||
@ -184,9 +181,11 @@ inline void CheckComputeCapability() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
DEV_INLINE void AtomicOrByte(unsigned int* __restrict__ buffer, size_t ibyte, unsigned char b) {
|
XGBOOST_DEV_INLINE void AtomicOrByte(unsigned int *__restrict__ buffer,
|
||||||
|
size_t ibyte, unsigned char b) {
|
||||||
atomicOr(&buffer[ibyte / sizeof(unsigned int)],
|
atomicOr(&buffer[ibyte / sizeof(unsigned int)],
|
||||||
static_cast<unsigned int>(b) << (ibyte % (sizeof(unsigned int)) * 8));
|
static_cast<unsigned int>(b)
|
||||||
|
<< (ibyte % (sizeof(unsigned int)) * 8));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -994,7 +993,7 @@ class SegmentSorter {
|
|||||||
|
|
||||||
// Atomic add function for gradients
|
// Atomic add function for gradients
|
||||||
template <typename OutputGradientT, typename InputGradientT>
|
template <typename OutputGradientT, typename InputGradientT>
|
||||||
DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
|
XGBOOST_DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
|
||||||
const InputGradientT& gpair) {
|
const InputGradientT& gpair) {
|
||||||
auto dst_ptr = reinterpret_cast<typename OutputGradientT::ValueT*>(dest);
|
auto dst_ptr = reinterpret_cast<typename OutputGradientT::ValueT*>(dest);
|
||||||
|
|
||||||
|
|||||||
@ -281,6 +281,33 @@ struct GHistIndexMatrix {
|
|||||||
bool isDense_;
|
bool isDense_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
 * \brief Binary search over data[begin, end) for a value that lies inside the
 *        half-open interval [fidx_begin, fidx_end).
 *
 * \tparam GradientIndex Any type indexable with operator[] whose elements can
 *                       be compared against uint32_t.
 *
 * \param begin      First index to inspect.
 * \param end        One past the last index to inspect.
 * \param data       Indexable container of bin indices. The narrowing logic
 *                   only works if the inspected range is in ascending order --
 *                   presumably guaranteed by the caller, confirm.
 * \param fidx_begin Inclusive lower bound of the accepted values.
 * \param fidx_end   Exclusive upper bound of the accepted values.
 *
 * \return The first probed value inside [fidx_begin, fidx_end), converted to
 *         int32_t, or -1 when no probed value falls inside the interval.
 */
template <typename GradientIndex>
|
||||||
|
int32_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(bst_uint begin, bst_uint end,
|
||||||
|
GradientIndex const &data,
|
||||||
|
uint32_t const fidx_begin,
|
||||||
|
uint32_t const fidx_end) {
|
||||||
|
// Sentinel that cannot match the first midpoint unless it equals UINT32_MAX.
uint32_t previous_middle = std::numeric_limits<uint32_t>::max();
|
||||||
|
while (end != begin) {
|
||||||
|
auto middle = begin + (end - begin) / 2;
|
||||||
|
// The lower bound is moved to `middle` (not `middle + 1`) below, so once the
// range shrinks to one element the midpoint repeats; stop when that happens.
if (middle == previous_middle) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
previous_middle = middle;
|
||||||
|
|
||||||
|
auto gidx = data[middle];
|
||||||
|
|
||||||
|
// Hit: the probed value belongs to the requested interval.
if (gidx >= fidx_begin && gidx < fidx_end) {
|
||||||
|
return static_cast<int32_t>(gidx);
|
||||||
|
// Probed value is below the interval: continue in the upper half.
} else if (gidx < fidx_begin) {
|
||||||
|
begin = middle;
|
||||||
|
// Probed value is at or above fidx_end: continue in the lower half.
} else {
|
||||||
|
end = middle;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Value is missing: no probed element fell inside [fidx_begin, fidx_end).
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
struct GHistIndexBlock {
|
struct GHistIndexBlock {
|
||||||
const size_t* row_ptr;
|
const size_t* row_ptr;
|
||||||
const uint32_t* index;
|
const uint32_t* index;
|
||||||
|
|||||||
@ -13,34 +13,6 @@
|
|||||||
#include <thrust/binary_search.h>
|
#include <thrust/binary_search.h>
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
// Find a gidx value for a given feature otherwise return -1 if not found
|
|
||||||
__forceinline__ __device__ int BinarySearchRow(
|
|
||||||
bst_uint begin, bst_uint end,
|
|
||||||
common::CompressedIterator<uint32_t> data,
|
|
||||||
int const fidx_begin, int const fidx_end) {
|
|
||||||
bst_uint previous_middle = UINT32_MAX;
|
|
||||||
while (end != begin) {
|
|
||||||
auto middle = begin + (end - begin) / 2;
|
|
||||||
if (middle == previous_middle) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
previous_middle = middle;
|
|
||||||
|
|
||||||
auto gidx = data[middle];
|
|
||||||
|
|
||||||
if (gidx >= fidx_begin && gidx < fidx_end) {
|
|
||||||
return gidx;
|
|
||||||
} else if (gidx < fidx_begin) {
|
|
||||||
begin = middle;
|
|
||||||
} else {
|
|
||||||
end = middle;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Value is missing
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** \brief Struct for accessing and manipulating an ellpack matrix on the
|
/** \brief Struct for accessing and manipulating an ellpack matrix on the
|
||||||
* device. Does not own underlying memory and may be trivially copied into
|
* device. Does not own underlying memory and may be trivially copied into
|
||||||
* kernels.*/
|
* kernels.*/
|
||||||
@ -83,7 +55,7 @@ struct EllpackDeviceAccessor {
|
|||||||
if (is_dense) {
|
if (is_dense) {
|
||||||
gidx = gidx_iter[row_begin + fidx];
|
gidx = gidx_iter[row_begin + fidx];
|
||||||
} else {
|
} else {
|
||||||
gidx = BinarySearchRow(row_begin,
|
gidx = common::BinarySearchBin(row_begin,
|
||||||
row_end,
|
row_end,
|
||||||
gidx_iter,
|
gidx_iter,
|
||||||
feature_segments[fidx],
|
feature_segments[fidx],
|
||||||
|
|||||||
@ -134,7 +134,7 @@ struct DeviceAdapterLoader {
|
|||||||
|
|
||||||
using BatchT = Batch;
|
using BatchT = Batch;
|
||||||
|
|
||||||
DEV_INLINE DeviceAdapterLoader(Batch const batch, bool use_shared,
|
XGBOOST_DEV_INLINE DeviceAdapterLoader(Batch const batch, bool use_shared,
|
||||||
bst_feature_t num_features, bst_row_t num_rows,
|
bst_feature_t num_features, bst_row_t num_rows,
|
||||||
size_t entry_start) :
|
size_t entry_start) :
|
||||||
batch{batch},
|
batch{batch},
|
||||||
@ -158,7 +158,7 @@ struct DeviceAdapterLoader {
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
DEV_INLINE float GetElement(size_t ridx, size_t fidx) const {
|
XGBOOST_DEV_INLINE float GetElement(size_t ridx, size_t fidx) const {
|
||||||
if (use_shared) {
|
if (use_shared) {
|
||||||
return smem[threadIdx.x * columns + fidx];
|
return smem[threadIdx.x * columns + fidx];
|
||||||
}
|
}
|
||||||
|
|||||||
@ -34,7 +34,7 @@ namespace tree {
|
|||||||
* to avoid outliers, as the full reduction is reproducible on GPU with reduction tree.
|
* to avoid outliers, as the full reduction is reproducible on GPU with reduction tree.
|
||||||
*/
|
*/
|
||||||
template <typename T>
|
template <typename T>
|
||||||
DEV_INLINE __host__ T CreateRoundingFactor(T max_abs, int n) {
|
XGBOOST_DEV_INLINE __host__ T CreateRoundingFactor(T max_abs, int n) {
|
||||||
T delta = max_abs / (static_cast<T>(1.0) - 2 * n * std::numeric_limits<T>::epsilon());
|
T delta = max_abs / (static_cast<T>(1.0) - 2 * n * std::numeric_limits<T>::epsilon());
|
||||||
|
|
||||||
// Calculate ceil(log_2(delta)).
|
// Calculate ceil(log_2(delta)).
|
||||||
@ -53,20 +53,20 @@ struct Pair {
|
|||||||
GradientPair first;
|
GradientPair first;
|
||||||
GradientPair second;
|
GradientPair second;
|
||||||
};
|
};
|
||||||
DEV_INLINE Pair operator+(Pair const& lhs, Pair const& rhs) {
|
XGBOOST_DEV_INLINE Pair operator+(Pair const& lhs, Pair const& rhs) {
|
||||||
return {lhs.first + rhs.first, lhs.second + rhs.second};
|
return {lhs.first + rhs.first, lhs.second + rhs.second};
|
||||||
}
|
}
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
struct Clip : public thrust::unary_function<GradientPair, Pair> {
|
struct Clip : public thrust::unary_function<GradientPair, Pair> {
|
||||||
static DEV_INLINE float Pclip(float v) {
|
static XGBOOST_DEV_INLINE float Pclip(float v) {
|
||||||
return v > 0 ? v : 0;
|
return v > 0 ? v : 0;
|
||||||
}
|
}
|
||||||
static DEV_INLINE float Nclip(float v) {
|
static XGBOOST_DEV_INLINE float Nclip(float v) {
|
||||||
return v < 0 ? abs(v) : 0;
|
return v < 0 ? abs(v) : 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
DEV_INLINE Pair operator()(GradientPair x) const {
|
XGBOOST_DEV_INLINE Pair operator()(GradientPair x) const {
|
||||||
auto pg = Pclip(x.GetGrad());
|
auto pg = Pclip(x.GetGrad());
|
||||||
auto ph = Pclip(x.GetHess());
|
auto ph = Pclip(x.GetHess());
|
||||||
|
|
||||||
|
|||||||
@ -16,7 +16,7 @@ template <typename GradientSumT>
|
|||||||
GradientSumT CreateRoundingFactor(common::Span<GradientPair const> gpair);
|
GradientSumT CreateRoundingFactor(common::Span<GradientPair const> gpair);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
DEV_INLINE T TruncateWithRoundingFactor(T const rounding_factor, float const x) {
|
XGBOOST_DEV_INLINE T TruncateWithRoundingFactor(T const rounding_factor, float const x) {
|
||||||
return (rounding_factor + static_cast<T>(x)) - rounding_factor;
|
return (rounding_factor + static_cast<T>(x)) - rounding_factor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -114,58 +114,6 @@ struct DeviceSplitCandidateReduceOp {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct DeviceNodeStats {
|
|
||||||
GradientPair sum_gradients;
|
|
||||||
float root_gain {-FLT_MAX};
|
|
||||||
float weight {-FLT_MAX};
|
|
||||||
|
|
||||||
/** default direction for missing values */
|
|
||||||
DefaultDirection dir {kLeftDir};
|
|
||||||
/** threshold value for comparison */
|
|
||||||
float fvalue {0.0f};
|
|
||||||
GradientPair left_sum;
|
|
||||||
GradientPair right_sum;
|
|
||||||
/** \brief The feature index. */
|
|
||||||
int fidx{kUnusedNode};
|
|
||||||
/** node id (used as key for reduce/scan) */
|
|
||||||
NodeIdT idx{kUnusedNode};
|
|
||||||
|
|
||||||
XGBOOST_DEVICE DeviceNodeStats() {} // NOLINT
|
|
||||||
|
|
||||||
template <typename ParamT>
|
|
||||||
HOST_DEV_INLINE DeviceNodeStats(GradientPair sum_gradients, NodeIdT nidx,
|
|
||||||
const ParamT& param)
|
|
||||||
: sum_gradients(sum_gradients),
|
|
||||||
idx(nidx) {
|
|
||||||
this->root_gain =
|
|
||||||
CalcGain(param, sum_gradients.GetGrad(), sum_gradients.GetHess());
|
|
||||||
this->weight =
|
|
||||||
CalcWeight(param, sum_gradients.GetGrad(), sum_gradients.GetHess());
|
|
||||||
}
|
|
||||||
|
|
||||||
HOST_DEV_INLINE void SetSplit(float fvalue, int fidx, DefaultDirection dir,
|
|
||||||
GradientPair left_sum, GradientPair right_sum) {
|
|
||||||
this->fvalue = fvalue;
|
|
||||||
this->fidx = fidx;
|
|
||||||
this->dir = dir;
|
|
||||||
this->left_sum = left_sum;
|
|
||||||
this->right_sum = right_sum;
|
|
||||||
}
|
|
||||||
|
|
||||||
HOST_DEV_INLINE void SetSplit(const DeviceSplitCandidate& split) {
|
|
||||||
this->SetSplit(split.fvalue, split.findex, split.dir, split.left_sum,
|
|
||||||
split.right_sum);
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Tells whether this node is part of the decision tree */
|
|
||||||
HOST_DEV_INLINE bool IsUnused() const { return (idx == kUnusedNode); }
|
|
||||||
|
|
||||||
/** Tells whether this node is a leaf of the decision tree */
|
|
||||||
HOST_DEV_INLINE bool IsLeaf() const {
|
|
||||||
return (!IsUnused() && (fidx == kUnusedNode));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct SumCallbackOp {
|
struct SumCallbackOp {
|
||||||
// Running prefix
|
// Running prefix
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user