86 lines
3.0 KiB
Plaintext
86 lines
3.0 KiB
Plaintext
/*!
|
|
 * Copyright 2016 Rory Mitchell
|
|
*/
|
|
#pragma once
|
|
#include <cub/cub.cuh>
|
|
#include <xgboost/base.h>
|
|
#include "device_helpers.cuh"
|
|
#include "find_split_multiscan.cuh"
|
|
#include "find_split_sorting.cuh"
|
|
#include "types_functions.cuh"
|
|
|
|
namespace xgboost {
|
|
namespace tree {
|
|
|
|
/*!
 * \brief Pick the best split for every node of the current level and create
 *        the two child nodes that result from it.
 *
 * One thread processes one node: the flat global thread index is used as the
 * node index within the level. Threads whose index is not below
 * n_current_nodes do nothing.
 *
 * \param d_split_candidates Per-feature split candidates. The candidate for
 *        feature f and node n is read from index n_current_nodes * f + n.
 * \param d_current_nodes Nodes of the current level; the `split` member of
 *        each processed node is overwritten with the winning candidate.
 * \param d_new_nodes Nodes of the next level; the children of node n are
 *        written to indices 2 * n (from left_sum) and 2 * n + 1 (from
 *        right_sum).
 * \param n_current_nodes Number of nodes on the current level.
 * \param n_features Number of features, i.e. candidates examined per node.
 * \param param Training parameters forwarded to CalcGain and CalcWeight.
 */
__global__ void
reduce_split_candidates_kernel(Split *d_split_candidates, Node *d_current_nodes,
                               Node *d_new_nodes, int n_current_nodes,
                               int n_features, const GPUTrainingParam param) {
  const int node_idx = blockIdx.x * blockDim.x + threadIdx.x;

  // The grid may contain more threads than nodes; surplus threads fall through.
  if (node_idx < n_current_nodes) {
    // Walk this node's candidates, one per feature. Consecutive features are
    // n_current_nodes entries apart in the candidate array.
    Split winner;
    Split *candidate = d_split_candidates + node_idx;
    for (int feature = 0; feature < n_features; ++feature) {
      winner.Update(*candidate);
      candidate += n_current_nodes;
    }

    // Record the chosen split on the node being expanded.
    d_current_nodes[node_idx].split = winner;

    // Children of this node occupy two adjacent slots in the next level.
    Node *children = d_new_nodes + node_idx * 2;

    // Left child, built from the left gradient sum.
    children[0] = Node(
        winner.left_sum,
        CalcGain(param, winner.left_sum.grad(), winner.left_sum.hess()),
        CalcWeight(param, winner.left_sum.grad(), winner.left_sum.hess()));

    // Right child, built from the right gradient sum.
    children[1] = Node(
        winner.right_sum,
        CalcGain(param, winner.right_sum.grad(), winner.right_sum.hess()),
        CalcWeight(param, winner.right_sum.grad(), winner.right_sum.hess()));
  }
}
|
|
|
|
/*!
 * \brief Host wrapper: reduce the per-feature split candidates of one tree
 *        level to a single best split per node and create the child nodes.
 *
 * The node array is indexed so that level L starts at offset (1 << L) - 1 and
 * holds 1 << L nodes; the next level therefore starts at (1 << (L + 1)) - 1.
 * The function blocks until the kernel has finished.
 *
 * \param d_split_candidates Per-feature split candidates, passed straight to
 *        reduce_split_candidates_kernel.
 * \param d_nodes Base pointer of the node array covering all levels.
 * \param level Level being expanded (0 for the first level of the array).
 * \param n_features Number of features, i.e. candidates per node.
 * \param param Training parameters forwarded to the kernel.
 */
void reduce_split_candidates(Split *d_split_candidates, Node *d_nodes,
                             int level, int n_features,
                             const GPUTrainingParam param) {
  // Current level nodes (before split)
  Node *d_current_nodes = d_nodes + (1 << (level)) - 1;
  // Next level nodes (after split)
  Node *d_new_nodes = d_nodes + (1 << (level + 1)) - 1;
  // Number of existing nodes on this level
  int n_current_nodes = 1 << level;

  // One thread per node; the grid is rounded up and the kernel bounds-checks.
  const int BLOCK_THREADS = 512;
  const int GRID_SIZE = dh::div_round_up(n_current_nodes, BLOCK_THREADS);

  reduce_split_candidates_kernel<<<GRID_SIZE, BLOCK_THREADS>>>(
      d_split_candidates, d_current_nodes, d_new_nodes, n_current_nodes,
      n_features, param);
  // A kernel launch returns no status itself: launch-time failures (e.g. an
  // invalid configuration) must be fetched explicitly with cudaGetLastError.
  dh::safe_cuda(cudaGetLastError());
  // Wait for completion; this also surfaces faults raised while the kernel ran.
  dh::safe_cuda(cudaDeviceSynchronize());
}
|
|
|
|
/*!
 * \brief Find the best split for every node on the given level.
 *
 * Works in two stages: first one of two candidate-search routines fills
 * d_split_candidates, then reduce_split_candidates selects the best candidate
 * per node and creates the nodes of the next level.
 *
 * \param items_iter Iterator over the training items, forwarded to the
 *        candidate search.
 * \param d_split_candidates Buffer receiving the split candidates.
 * \param d_nodes Base pointer of the node array.
 * \param num_items Number of items, forwarded to the candidate search.
 * \param num_features Number of features.
 * \param d_feature_offsets Feature offsets, forwarded to the candidate search.
 * \param d_node_sums Passed only to the sorting-based search.
 * \param d_node_offsets Passed only to the sorting-based search.
 * \param param Training parameters.
 * \param level Tree level being expanded.
 * \param multiscan_algorithm When true the multiscan search is used,
 *        otherwise the sorting-based search.
 */
void find_split(const ItemIter items_iter, Split *d_split_candidates,
                Node *d_nodes, bst_uint num_items, int num_features,
                const int *d_feature_offsets, gpu_gpair *d_node_sums,
                int *d_node_offsets, const GPUTrainingParam param,
                const int level, bool multiscan_algorithm) {
  // Stage 1: produce the split candidates with the selected algorithm.
  if (!multiscan_algorithm) {
    find_split_candidates_sorted(items_iter, d_split_candidates, d_nodes,
                                 num_items, num_features, d_feature_offsets,
                                 d_node_sums, d_node_offsets, param, level);
  } else {
    find_split_candidates_multiscan(items_iter, d_split_candidates, d_nodes,
                                    num_items, num_features, d_feature_offsets,
                                    param, level);
  }

  // Stage 2: keep only the best candidate for each node.
  reduce_split_candidates(d_split_candidates, d_nodes, level, num_features,
                          param);
}
|
|
} // namespace tree
|
|
} // namespace xgboost
|