[GPU-Plugin] Various fixes (#2579)

* Fix test_large
* Add check for max_depth 0
* Update readme
* Add LBS specialisation for dense data
* Add bst_gpair_precise
* Temporarily disable accuracy tests on test_large.py
* Solve unused variable compiler warning
* Fix max_bin > 1024 error

parent 03e213c7cd
commit eda9e180f0
@@ -82,49 +82,67 @@ typedef uint64_t bst_ulong;  // NOLINT(*)
 /*! \brief float type, used for storing statistics */
 typedef float bst_float;
 
-/*! \brief gradient statistics pair usually needed in gradient boosting */
-struct bst_gpair {
+/*! \brief Implementation of gradient statistics pair */
+template <typename T>
+struct bst_gpair_internal {
   /*! \brief gradient statistics */
-  bst_float grad;
+  T grad;
   /*! \brief second order gradient statistics */
-  bst_float hess;
+  T hess;
 
-  XGBOOST_DEVICE bst_gpair() : grad(0), hess(0) {}
+  XGBOOST_DEVICE bst_gpair_internal() : grad(0), hess(0) {}
 
-  XGBOOST_DEVICE bst_gpair(bst_float grad, bst_float hess)
+  XGBOOST_DEVICE bst_gpair_internal(T grad, T hess)
       : grad(grad), hess(hess) {}
 
-  XGBOOST_DEVICE bst_gpair &operator+=(const bst_gpair &rhs) {
+  template <typename T2>
+  XGBOOST_DEVICE bst_gpair_internal(const bst_gpair_internal<T2> &g)
+      : grad(g.grad), hess(g.hess) {}
+
+  XGBOOST_DEVICE bst_gpair_internal<T> &operator+=(
+      const bst_gpair_internal<T> &rhs) {
     grad += rhs.grad;
     hess += rhs.hess;
     return *this;
   }
 
-  XGBOOST_DEVICE bst_gpair operator+(const bst_gpair &rhs) const {
-    bst_gpair g;
+  XGBOOST_DEVICE bst_gpair_internal<T> operator+(
+      const bst_gpair_internal<T> &rhs) const {
+    bst_gpair_internal<T> g;
     g.grad = grad + rhs.grad;
     g.hess = hess + rhs.hess;
     return g;
   }
 
-  XGBOOST_DEVICE bst_gpair &operator-=(const bst_gpair &rhs) {
+  XGBOOST_DEVICE bst_gpair_internal<T> &operator-=(
+      const bst_gpair_internal<T> &rhs) {
    grad -= rhs.grad;
    hess -= rhs.hess;
    return *this;
  }
 
-  XGBOOST_DEVICE bst_gpair operator-(const bst_gpair &rhs) const {
-    bst_gpair g;
+  XGBOOST_DEVICE bst_gpair_internal<T> operator-(
+      const bst_gpair_internal<T> &rhs) const {
+    bst_gpair_internal<T> g;
     g.grad = grad - rhs.grad;
     g.hess = hess - rhs.hess;
     return g;
   }
 
-  XGBOOST_DEVICE bst_gpair(int value) {
-    *this = bst_gpair(static_cast<float>(value), static_cast<float>(value));
+  XGBOOST_DEVICE bst_gpair_internal(int value) {
+    *this = bst_gpair_internal<T>(static_cast<T>(value), static_cast<T>(value));
   }
+
+  friend std::ostream &operator<<(std::ostream &os,
+                                  const bst_gpair_internal<T> &g) {
+    os << g.grad << "/" << g.hess;
+    return os;
+  }
 };
 
+/*! \brief gradient statistics pair usually needed in gradient boosting */
+typedef bst_gpair_internal<float> bst_gpair;
+
+/*! \brief High precision gradient statistics pair */
+typedef bst_gpair_internal<double> bst_gpair_precise;
+
 /*! \brief small eps gap for minimum split decision. */
 const bst_float rt_eps = 1e-6f;
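The template above is what the new `bst_gpair_precise` typedef builds on. For context (not part of the commit), a minimal host-side C++ sketch of why a double-precision pair matters when many single-precision gradient pairs are summed; `gpair` is a simplified stand-in for `bst_gpair_internal<T>`:

```cpp
#include <cstdio>

template <typename T>
struct gpair {  // simplified stand-in for bst_gpair_internal<T>
  T grad, hess;
  gpair() : grad(0), hess(0) {}
  gpair(T g, T h) : grad(g), hess(h) {}
  gpair &operator+=(const gpair &rhs) {
    grad += rhs.grad;
    hess += rhs.hess;
    return *this;
  }
};

int main() {
  gpair<float> sum_f;
  gpair<double> sum_d;
  for (int i = 0; i < 10000000; ++i) {
    gpair<float> g(1e-4f, 2e-4f);            // typical small per-row gradient
    sum_f += g;                              // accumulates in float
    sum_d += gpair<double>(g.grad, g.hess);  // converting accumulate in double
  }
  // The exact total is 1000; the float sum drifts visibly once the running
  // total dwarfs each increment, while the double sum stays accurate.
  std::printf("float sum: %f  double sum: %f\n", sum_f.grad, sum_d.grad);
}
```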
@@ -27,7 +27,7 @@ This plugin currently works with the CLI version and python version.
 
 Python example:
 ```python
-param['gpu_id'] = 1
+param['gpu_id'] = 0
 param['max_bin'] = 16
 param['tree_method'] = 'gpu_hist'
 ```
@@ -15,28 +15,30 @@
 namespace xgboost {
 namespace tree {
 
+template <typename gpair_t>
 __device__ inline float device_calc_loss_chg(const GPUTrainingParam& param,
-                                             const bst_gpair& scan,
-                                             const bst_gpair& missing,
-                                             const bst_gpair& parent_sum,
+                                             const gpair_t& scan,
+                                             const gpair_t& missing,
+                                             const gpair_t& parent_sum,
                                              const float& parent_gain,
                                              bool missing_left) {
-  bst_gpair left = scan;
+  gpair_t left = scan;
 
   if (missing_left) {
     left += missing;
   }
 
-  bst_gpair right = parent_sum - left;
+  gpair_t right = parent_sum - left;
 
   float left_gain = CalcGain(param, left.grad, left.hess);
   float right_gain = CalcGain(param, right.grad, right.hess);
   return left_gain + right_gain - parent_gain;
 }
 
-__device__ float inline loss_chg_missing(const bst_gpair& scan,
-                                         const bst_gpair& missing,
-                                         const bst_gpair& parent_sum,
+template <typename gpair_t>
+__device__ float inline loss_chg_missing(const gpair_t& scan,
+                                         const gpair_t& missing,
+                                         const gpair_t& parent_sum,
                                          const float& parent_gain,
                                          const GPUTrainingParam& param,
                                          bool& missing_left_out) {  // NOLINT
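For context, a hedged host-side sketch of the split-gain logic these now-templated device functions compute. `calc_gain` below is a simplified stand-in for the real `CalcGain`, which also applies regularisation from `GPUTrainingParam`; the numbers are made up:

```cpp
#include <cstdio>

struct Pair { double grad, hess; };

// Simplified gain: grad^2 / hess. The real CalcGain also applies
// regularisation terms from the training parameters.
double calc_gain(const Pair &p) { return p.grad * p.grad / p.hess; }

// Mirrors device_calc_loss_chg: optionally route the missing-value
// bucket left, derive the right child by subtraction from the parent.
double loss_chg(Pair scan, Pair missing, Pair parent, double parent_gain,
                bool missing_left) {
  Pair left = scan;
  if (missing_left) {
    left.grad += missing.grad;
    left.hess += missing.hess;
  }
  Pair right = {parent.grad - left.grad, parent.hess - left.hess};
  return calc_gain(left) + calc_gain(right) - parent_gain;
}

int main() {
  Pair parent = {10.0, 20.0}, scan = {6.0, 4.0}, missing = {1.0, 2.0};
  double parent_gain = calc_gain(parent);
  // loss_chg_missing effectively evaluates both directions and keeps the best:
  std::printf("missing left:  %f\n",
              loss_chg(scan, missing, parent, parent_gain, true));
  std::printf("missing right: %f\n",
              loss_chg(scan, missing, parent, parent_gain, false));
}
```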
@@ -177,11 +179,11 @@ inline std::vector<int> col_sample(std::vector<int> features, float colsample) {
 }
 struct GpairCallbackOp {
   // Running prefix
-  bst_gpair running_total;
+  bst_gpair_precise running_total;
   // Constructor
-  __device__ GpairCallbackOp() : running_total(bst_gpair()) {}
-  __device__ bst_gpair operator()(bst_gpair block_aggregate) {
-    bst_gpair old_prefix = running_total;
+  __device__ GpairCallbackOp() : running_total(bst_gpair_precise()) {}
+  __device__ bst_gpair_precise operator()(bst_gpair_precise block_aggregate) {
+    bst_gpair_precise old_prefix = running_total;
     running_total += block_aggregate;
     return old_prefix;
   }
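A hedged host-side analogue (not commit code) of what `GpairCallbackOp` does for `cub::BlockScan`: when a segment spans several tiles of `BLOCK_THREADS` elements, CUB invokes the callback once per tile with that tile's aggregate, and the callback returns the running total of all earlier tiles so the exclusive scan stays correct across tile boundaries. Plain doubles stand in for gradient pairs:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Host analogue of GpairCallbackOp: carries the prefix between tiles.
struct CallbackOp {
  double running_total = 0.0;
  double operator()(double block_aggregate) {
    double old_prefix = running_total;
    running_total += block_aggregate;
    return old_prefix;  // prefix added to every element of the new tile
  }
};

int main() {
  std::vector<double> bins = {1, 2, 3, 4, 5, 6, 7, 8};
  const std::size_t TILE = 3;  // stand-in for BLOCK_THREADS
  CallbackOp prefix_op;
  for (std::size_t begin = 0; begin < bins.size(); begin += TILE) {
    std::size_t end = std::min(begin + TILE, bins.size());
    // Tile-local exclusive scan, as BlockScanT::ExclusiveScan would do.
    std::vector<double> scanned;
    double tile_sum = 0.0;
    for (std::size_t i = begin; i < end; ++i) {
      scanned.push_back(tile_sum);
      tile_sum += bins[i];
    }
    double carry = prefix_op(tile_sum);  // total of all earlier tiles
    for (std::size_t i = 0; i < scanned.size(); ++i)
      std::printf("bin %zu: prefix %.1f\n", begin + i, scanned[i] + carry);
  }
}
```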
@@ -729,32 +729,8 @@ __global__ void LbsKernel(coordinate_t *d_coordinates,
   }
 }
 
-/**
- * \fn template <typename func_t, typename segments_iter, typename offset_t>
- * void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
- * segments_iter segments, offset_t num_segments, func_t f)
- *
- * \brief Load balancing search function. Reads a CSR type matrix description
- * and allows a function to be executed on each element. Search 'modern GPU load
- * balancing search' for more information.
- *
- * \author Rory
- * \date 7/9/2017
- *
- * \tparam func_t        Type of the function t.
- * \tparam segments_iter Type of the segments iterator.
- * \tparam offset_t      Type of the offset.
- * \tparam segments_t    Type of the segments t.
- * \param device_idx           Zero-based index of the device.
- * \param [in,out] temp_memory Temporary memory allocator.
- * \param count                Number of elements.
- * \param segments             Device pointer to segments.
- * \param num_segments         Number of segments.
- * \param f                    Lambda to be executed on matrix elements.
- */
-
 template <typename func_t, typename segments_iter, typename offset_t>
-void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
+void SparseTransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
                   segments_iter segments, offset_t num_segments, func_t f) {
   typedef typename cub::CubVector<offset_t, 2>::Type coordinate_t;
   dh::safe_cuda(cudaSetDevice(device_idx));
@@ -775,4 +751,47 @@ void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
                     num_segments);
 }
 
+template <typename func_t, typename offset_t>
+void DenseTransformLbs(int device_idx, offset_t count, offset_t num_segments,
+                       func_t f) {
+  CHECK(count % num_segments == 0) << "Data is not dense.";
+
+  launch_n(device_idx, count, [=] __device__(offset_t idx) {
+    offset_t segment = idx / (count / num_segments);
+    f(idx, segment);
+  });
+}
+
+/**
+ * \fn template <typename func_t, typename segments_iter, typename offset_t>
+ * void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
+ * segments_iter segments, offset_t num_segments, bool is_dense, func_t f)
+ *
+ * \brief Load balancing search function. Reads a CSR type matrix description
+ * and allows a function to be executed on each element. Search 'modern GPU
+ * load balancing search' for more information.
+ *
+ * \author Rory
+ * \date 7/9/2017
+ *
+ * \tparam func_t        Type of the function t.
+ * \tparam segments_iter Type of the segments iterator.
+ * \tparam offset_t      Type of the offset.
+ * \param device_idx           Zero-based index of the device.
+ * \param [in,out] temp_memory Temporary memory allocator.
+ * \param count                Number of elements.
+ * \param segments             Device pointer to segments.
+ * \param num_segments         Number of segments.
+ * \param is_dense             True if this object is dense.
+ * \param f                    Lambda to be executed on matrix elements.
+ */
+template <typename func_t, typename segments_iter, typename offset_t>
+void TransformLbs(int device_idx, dh::CubMemory *temp_memory, offset_t count,
+                  segments_iter segments, offset_t num_segments, bool is_dense,
+                  func_t f) {
+  if (is_dense) {
+    DenseTransformLbs(device_idx, count, num_segments, f);
+  } else {
+    SparseTransformLbs(device_idx, temp_memory, count, segments, num_segments,
+                       f);
+  }
+}
 }  // namespace dh
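A hedged host-side sketch of the dense specialisation's arithmetic: when every row holds the same number of entries, the row owning element `idx` is simply `idx / row_length`, so the sparse path's load balancing search can be skipped entirely. The linear `sparse_segment` below stands in for the GPU search:

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// Stand-in for the sparse path: locate idx's segment via the CSR row
// pointer (the CUDA version uses a parallel load balancing search).
int sparse_segment(const std::vector<int> &row_ptr, int idx) {
  int seg = 0;
  while (row_ptr[seg + 1] <= idx) ++seg;
  return seg;
}

int main() {
  const int num_segments = 4, row_len = 3;
  const int count = num_segments * row_len;
  std::vector<int> row_ptr = {0, 3, 6, 9, 12};  // dense: equal-length rows
  for (int idx = 0; idx < count; ++idx) {
    int dense_seg = idx / (count / num_segments);  // DenseTransformLbs rule
    assert(dense_seg == sparse_segment(row_ptr, idx));
    std::printf("element %d -> row %d\n", idx, dense_seg);
  }
}
```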
@@ -49,10 +49,10 @@ void DeviceHist::Init(int n_bins_in) {
 
 void DeviceHist::Reset(int device_idx) {
   cudaSetDevice(device_idx);
-  data.fill(bst_gpair());
+  data.fill(bst_gpair_precise());
 }
 
-bst_gpair* DeviceHist::GetLevelPtr(int depth) {
+bst_gpair_precise* DeviceHist::GetLevelPtr(int depth) {
   return data.data() + n_nodes(depth - 1) * n_bins;
 }
@@ -62,10 +62,25 @@ HistBuilder DeviceHist::GetBuilder() {
   return HistBuilder(data.data(), n_bins);
 }
 
-HistBuilder::HistBuilder(bst_gpair* ptr, int n_bins)
+HistBuilder::HistBuilder(bst_gpair_precise* ptr, int n_bins)
     : d_hist(ptr), n_bins(n_bins) {}
 
-__device__ void HistBuilder::Add(bst_gpair gpair, int gidx, int nidx) const {
+// Define double precision atomic add for older architectures
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600
+#else
+__device__ double atomicAdd(double* address, double val) {
+  unsigned long long int* address_as_ull = (unsigned long long int*)address;  // NOLINT
+  unsigned long long int old = *address_as_ull, assumed;  // NOLINT
+  do {
+    assumed = old;
+    old = atomicCAS(address_as_ull, assumed,
+                    __double_as_longlong(val + __longlong_as_double(assumed)));
+  } while (assumed != old);
+  return __longlong_as_double(old);
+}
+#endif
+
+__device__ void HistBuilder::Add(bst_gpair_precise gpair, int gidx, int nidx) const {
   int hist_idx = nidx * n_bins + gidx;
   atomicAdd(&(d_hist[hist_idx].grad), gpair.grad);  // OPTMARK: This and below
                                                     // line lead to about 3X
@@ -75,7 +90,7 @@ __device__ void HistBuilder::Add(bst_gpair gpair, int gidx, int nidx) const {
   atomicAdd(&(d_hist[hist_idx].hess), gpair.hess);
 }
 
-__device__ bst_gpair HistBuilder::Get(int gidx, int nidx) const {
+__device__ bst_gpair_precise HistBuilder::Get(int gidx, int nidx) const {
   return d_hist[nidx * n_bins + gidx];
 }
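The `atomicAdd` fallback above is the standard compare-and-swap retry idiom for devices (`__CUDA_ARCH__ < 600`, i.e. pre-Pascal) without a native double-precision atomic add. A hedged host-side analogue of the same idiom using `std::atomic` (not commit code):

```cpp
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

// Same idea as the CUDA fallback: read the current value, compute the
// new one, and retry with compare-and-swap until no other thread interfered.
double atomic_add(std::atomic<double> *address, double val) {
  double old = address->load();
  while (!address->compare_exchange_weak(old, old + val)) {
    // 'old' is refreshed with the current value on failure; just retry.
  }
  return old;  // value before our addition, like atomicAdd
}

int main() {
  std::atomic<double> hist_bin{0.0};
  std::vector<std::thread> threads;
  for (int t = 0; t < 8; ++t)
    threads.emplace_back([&] {
      for (int i = 0; i < 100000; ++i) atomic_add(&hist_bin, 0.5);
    });
  for (auto &th : threads) th.join();
  std::printf("expected 400000.0, got %f\n", hist_bin.load());
}
```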
@@ -104,6 +119,7 @@ GPUHistBuilder::~GPUHistBuilder() {
 
 void GPUHistBuilder::Init(const TrainParam& param) {
   CHECK(param.max_depth < 16) << "Tree depth too large.";
+  CHECK(param.max_depth != 0) << "Tree depth cannot be 0.";
   CHECK(param.grow_policy != TrainParam::kLossGuide)
       << "Loss guided growth policy not supported. Use CPU algorithm.";
   this->param = param;
@@ -358,7 +374,7 @@ void GPUHistBuilder::BuildHist(int depth) {
     auto hist_builder = hist_vec[d_idx].GetBuilder();
     dh::TransformLbs(
         device_idx, &temp_memory[d_idx], end - begin, d_row_ptr,
-        row_end - row_begin, [=] __device__(size_t local_idx, int local_ridx) {
+        row_end - row_begin, is_dense, [=] __device__(size_t local_idx, int local_ridx) {
          int nidx = d_position[local_ridx];  // OPTMARK: latency
          if (!is_active(nidx, depth)) return;
@@ -396,8 +412,8 @@ void GPUHistBuilder::BuildHist(int depth) {
     dh::safe_nccl(ncclAllReduce(
         reinterpret_cast<const void*>(hist_vec[d_idx].GetLevelPtr(depth)),
         reinterpret_cast<void*>(hist_vec[d_idx].GetLevelPtr(depth)),
-        hist_vec[d_idx].LevelSize(depth) * sizeof(bst_gpair) / sizeof(float),
-        ncclFloat, ncclSum, comms[d_idx], *(streams[d_idx])));
+        hist_vec[d_idx].LevelSize(depth) * sizeof(bst_gpair_precise) / sizeof(double),
+        ncclDouble, ncclSum, comms[d_idx], *(streams[d_idx])));
   }
 
   for (int d_idx = 0; d_idx < n_devices; d_idx++) {
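The element count handed to `ncclAllReduce` assumes `bst_gpair_precise` is exactly two packed doubles, so each level's histogram can be reinterpreted as a flat double array and reduced with `ncclDouble`. A compile-time guard one could add (an assumption of this note, not part of the commit):

```cpp
#include <xgboost/base.h>

// Holds as long as bst_gpair_internal<double> stays two adjacent doubles
// with no padding.
static_assert(sizeof(xgboost::bst_gpair_precise) == 2 * sizeof(double),
              "histogram must reinterpret cleanly as a flat double array");
```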
@@ -428,9 +444,9 @@ void GPUHistBuilder::BuildHist(int depth) {
         }
 
         int gidx = idx % hist_builder.n_bins;
-        bst_gpair parent = hist_builder.Get(gidx, parent_nidx(nidx));
+        bst_gpair_precise parent = hist_builder.Get(gidx, parent_nidx(nidx));
         int other_nidx = left_smallest ? nidx - 1 : nidx + 1;
-        bst_gpair other = hist_builder.Get(gidx, other_nidx);
+        bst_gpair_precise other = hist_builder.Get(gidx, other_nidx);
         hist_builder.Add(parent - other, gidx,
                          nidx);  // OPTMARK: This is slow, could use shared
                                  // memory or cache results instead of writing to
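For context, the subtraction trick this hunk moves to double precision: only the smaller child's histogram is accumulated from the data, and the sibling's is derived as parent minus sibling, roughly halving histogram work per level. A minimal host-side sketch with one double per bin (the real histograms store a `bst_gpair_precise` per bin):

```cpp
#include <cassert>
#include <vector>

int main() {
  const int n_bins = 4;
  // Per-node gradient histograms, one value per feature bin.
  std::vector<double> parent = {4.0, 6.0, 2.0, 8.0};
  std::vector<double> left = {1.0, 2.0, 0.5, 3.0};  // built from the data
  std::vector<double> right(n_bins);
  for (int gidx = 0; gidx < n_bins; ++gidx)
    right[gidx] = parent[gidx] - left[gidx];  // derived without a data pass
  assert(right[3] == 5.0);
  return 0;
}
```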
@@ -443,16 +459,16 @@ void GPUHistBuilder::BuildHist(int depth) {
 
 template <int BLOCK_THREADS>
 __global__ void find_split_kernel(
-    const bst_gpair* d_level_hist, int* d_feature_segments, int depth,
+    const bst_gpair_precise* d_level_hist, int* d_feature_segments, int depth,
     int n_features, int n_bins, Node* d_nodes, Node* d_nodes_temp,
     Node* d_nodes_child_temp, int nodes_offset_device, float* d_fidx_min_map,
     float* d_gidx_fvalue_map, GPUTrainingParam gpu_param,
     bool* d_left_child_smallest_temp, bool colsample, int* d_feature_flags) {
   typedef cub::KeyValuePair<int, float> ArgMaxT;
-  typedef cub::BlockScan<bst_gpair, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>
+  typedef cub::BlockScan<bst_gpair_precise, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>
       BlockScanT;
   typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT;
-  typedef cub::BlockReduce<bst_gpair, BLOCK_THREADS> SumReduceT;
+  typedef cub::BlockReduce<bst_gpair_precise, BLOCK_THREADS> SumReduceT;
 
   union TempStorage {
     typename BlockScanT::TempStorage scan;
@@ -461,12 +477,12 @@ __global__ void find_split_kernel(
   };
 
   struct UninitializedSplit : cub::Uninitialized<Split> {};
-  struct UninitializedGpair : cub::Uninitialized<bst_gpair> {};
+  struct UninitializedGpair : cub::Uninitialized<bst_gpair_precise> {};
 
   __shared__ UninitializedSplit uninitialized_split;
   Split& split = uninitialized_split.Alias();
   __shared__ UninitializedGpair uninitialized_sum;
-  bst_gpair& shared_sum = uninitialized_sum.Alias();
+  bst_gpair_precise& shared_sum = uninitialized_sum.Alias();
   __shared__ ArgMaxT block_max;
   __shared__ TempStorage temp_storage;
@@ -486,15 +502,14 @@ __global__ void find_split_kernel(
 
     int begin = d_feature_segments[level_node_idx * n_features + fidx];
     int end = d_feature_segments[level_node_idx * n_features + fidx + 1];
-    int gidx = (begin - (level_node_idx * n_bins)) + threadIdx.x;
-    bool thread_active = threadIdx.x < end - begin;
 
-    bst_gpair feature_sum = bst_gpair();
+    bst_gpair_precise feature_sum = bst_gpair_precise();
     for (int reduce_begin = begin; reduce_begin < end;
          reduce_begin += BLOCK_THREADS) {
+      bool thread_active = reduce_begin + threadIdx.x < end;
       // Scan histogram
-      bst_gpair bin = thread_active ? d_level_hist[reduce_begin + threadIdx.x]
-                                    : bst_gpair();
+      bst_gpair_precise bin = thread_active
+                                  ? d_level_hist[reduce_begin + threadIdx.x]
+                                  : bst_gpair_precise();
 
       feature_sum +=
           SumReduceT(temp_storage.sum_reduce).Reduce(bin, cub::Sum());
@@ -508,17 +523,18 @@ __global__ void find_split_kernel(
     GpairCallbackOp prefix_op = GpairCallbackOp();
     for (int scan_begin = begin; scan_begin < end;
          scan_begin += BLOCK_THREADS) {
-      bst_gpair bin =
-          thread_active ? d_level_hist[scan_begin + threadIdx.x] : bst_gpair();
+      bool thread_active = scan_begin + threadIdx.x < end;
+      bst_gpair_precise bin = thread_active
+                                  ? d_level_hist[scan_begin + threadIdx.x]
+                                  : bst_gpair_precise();
 
       BlockScanT(temp_storage.scan)
           .ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
 
       // Calculate gain
-      bst_gpair parent_sum = d_nodes[node_idx].sum_gradients;
+      bst_gpair_precise parent_sum = d_nodes[node_idx].sum_gradients;
       float parent_gain = d_nodes[node_idx].root_gain;
 
-      bst_gpair missing = parent_sum - shared_sum;
+      bst_gpair_precise missing = parent_sum - shared_sum;
 
       bool missing_left;
       float gain = thread_active
@@ -541,6 +557,7 @@ __global__ void find_split_kernel(
       // Best thread updates split
       if (threadIdx.x == block_max.key) {
         float fvalue;
+        int gidx = (scan_begin - (level_node_idx * n_bins)) + threadIdx.x;
         if (threadIdx.x == 0 &&
             begin == scan_begin) {  // check at start of first tile
           fvalue = d_fidx_min_map[fidx];
@@ -548,8 +565,8 @@ __global__ void find_split_kernel(
           fvalue = d_gidx_fvalue_map[gidx - 1];
         }
 
-        bst_gpair left = missing_left ? bin + missing : bin;
-        bst_gpair right = parent_sum - left;
+        bst_gpair_precise left = missing_left ? bin + missing : bin;
+        bst_gpair_precise right = parent_sum - left;
 
         split.Update(gain, missing_left, fvalue, fidx, left, right, gpu_param);
       }
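Together with the per-tile `thread_active` in the previous hunks, this is the `max_bin > 1024` fix from the commit message: `gidx` and `thread_active` were previously computed once from `threadIdx.x`, which only works while a feature's bin range fits in a single tile of `BLOCK_THREADS` threads. Recomputing both per tile makes wider segments scan correctly. A hedged host-side illustration of the per-tile indexing (the `level_node_idx * n_bins` offset is taken as zero here for brevity):

```cpp
#include <cstdio>

int main() {
  const int BLOCK_THREADS = 4;    // stand-in; the kernel uses a larger block
  const int begin = 0, end = 10;  // a feature segment wider than one tile
  for (int scan_begin = begin; scan_begin < end;
       scan_begin += BLOCK_THREADS) {
    for (int tid = 0; tid < BLOCK_THREADS; ++tid) {  // threads in the block
      bool thread_active = scan_begin + tid < end;   // per-tile, the fix
      int gidx = scan_begin + tid;  // recomputed per tile as in the hunk
      if (thread_active)
        std::printf("tile at %d: thread %d handles bin %d\n", scan_begin,
                    tid, gidx);
    }
  }
}
```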
@@ -662,7 +679,7 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
 
     int nodes_offset_device = d_idx * num_nodes_device;
     find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-        (const bst_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
+        (const bst_gpair_precise*)(hist_vec[d_idx].GetLevelPtr(depth)),
         feature_segments[d_idx].data(), depth, (info->num_col),
         (hmat_.row_ptr.back()), nodes[d_idx].data(), nodes_temp[d_idx].data(),
         nodes_child_temp[d_idx].data(), nodes_offset_device,
@@ -759,7 +776,7 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
 
     int nodes_offset_device = d_idx * num_nodes_device;
     find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-        (const bst_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
+        (const bst_gpair_precise*)(hist_vec[d_idx].GetLevelPtr(depth)),
         feature_segments[d_idx].data(), depth, (info->num_col),
         (hmat_.row_ptr.back()), nodes[d_idx].data(), NULL, NULL,
         nodes_offset_device, fidx_min_map[d_idx].data(),
@@ -812,7 +829,7 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
 
     int nodes_offset_device = 0;
     find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-        (const bst_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
+        (const bst_gpair_precise*)(hist_vec[d_idx].GetLevelPtr(depth)),
         feature_segments[d_idx].data(), depth, (info->num_col),
         (hmat_.row_ptr.back()), nodes[d_idx].data(), NULL, NULL,
         nodes_offset_device, fidx_min_map[d_idx].data(),
@@ -958,7 +975,7 @@ void GPUHistBuilder::UpdatePositionSparse(int depth) {
 
   dh::TransformLbs(
       device_idx, &temp_memory[d_idx], element_end - element_begin, d_row_ptr,
-      row_end - row_begin, [=] __device__(size_t local_idx, int local_ridx) {
+      row_end - row_begin, is_dense, [=] __device__(size_t local_idx, int local_ridx) {
        int pos = d_position[local_ridx];
        if (!is_active(pos, depth)) {
          return;
@@ -25,16 +25,16 @@ struct DeviceGMat {
 };
 
 struct HistBuilder {
-  bst_gpair *d_hist;
+  bst_gpair_precise *d_hist;
   int n_bins;
-  __host__ __device__ HistBuilder(bst_gpair *ptr, int n_bins);
-  __device__ void Add(bst_gpair gpair, int gidx, int nidx) const;
-  __device__ bst_gpair Get(int gidx, int nidx) const;
+  __host__ __device__ HistBuilder(bst_gpair_precise *ptr, int n_bins);
+  __device__ void Add(bst_gpair_precise gpair, int gidx, int nidx) const;
+  __device__ bst_gpair_precise Get(int gidx, int nidx) const;
 };
 
 struct DeviceHist {
   int n_bins;
-  dh::dvec<bst_gpair> data;
+  dh::dvec<bst_gpair_precise> data;
 
   void Init(int max_depth);
 
@@ -42,7 +42,7 @@ struct DeviceHist {
 
   HistBuilder GetBuilder();
 
-  bst_gpair *GetLevelPtr(int depth);
+  bst_gpair_precise *GetLevelPtr(int depth);
 
   int LevelSize(int depth);
 };
@@ -82,22 +82,16 @@ class TestGPU(unittest.TestCase):
 
         num_rounds = 1
 
-        eprint("normal updater")
-        xgb.train(ag_param, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
-                  evals_result=ag_res)
         eprint("hist updater")
-        xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+        xgb.train(ag_paramb, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
                   evals_result=ag_resb)
         eprint("gpu_hist updater 1 gpu")
-        xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+        xgb.train(ag_param2, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
                   evals_result=ag_res2)
         eprint("gpu_hist updater all gpus")
-        xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train'), (ag_dtest, 'test')],
+        xgb.train(ag_param3, ag_dtrain, num_rounds, [(ag_dtrain, 'train')],
                   evals_result=ag_res3)
 
-        assert np.fabs(ag_res['train']['auc'][0] - ag_resb['train']['auc'][0])<0.001
-        assert np.fabs(ag_res['train']['auc'][0] - ag_res2['train']['auc'][0])<0.001
-        assert np.fabs(ag_res['train']['auc'][0] - ag_res3['train']['auc'][0])<0.001
@@ -117,7 +117,6 @@ struct GBTreeModel {
   }
   void CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
                    int bst_group) {
-    size_t old_ntree = trees.size();
     for (size_t i = 0; i < new_trees.size(); ++i) {
       trees.push_back(std::move(new_trees[i]));
       tree_info.push_back(bst_group);