[GPU-Plugin] Unify gpu_gpair/bst_gpair. Refactor. (#2477)
parent d535340459
commit 5f1b0bb386
@@ -19,7 +19,8 @@
 /*!
  * \brief Whether always log console message with time.
  * It will display like, with timestamp appended to head of the message.
- * "[21:47:50] 6513x126 matrix with 143286 entries loaded from ../data/agaricus.txt.train"
+ * "[21:47:50] 6513x126 matrix with 143286 entries loaded from
+ * ../data/agaricus.txt.train"
  */
 #ifndef XGBOOST_LOG_WITH_TIME
 #define XGBOOST_LOG_WITH_TIME 1
@@ -36,7 +37,7 @@
  * \brief Whether to customize global PRNG.
  */
 #ifndef XGBOOST_CUSTOMIZE_GLOBAL_PRNG
 #define XGBOOST_CUSTOMIZE_GLOBAL_PRNG XGBOOST_STRICT_R_MODE
 #endif
 
 /*!
@@ -48,16 +49,27 @@
 #define XGBOOST_ALIGNAS(X)
 #endif
 
-#if defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ >= 8 && !defined(__CUDACC__)
+#if defined(__GNUC__) && __GNUC__ == 4 && __GNUC_MINOR__ >= 8 && \
+    !defined(__CUDACC__)
 #include <parallel/algorithm>
 #define XGBOOST_PARALLEL_SORT(X, Y, Z) __gnu_parallel::sort((X), (Y), (Z))
-#define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) __gnu_parallel::stable_sort((X), (Y), (Z))
+#define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) \
+  __gnu_parallel::stable_sort((X), (Y), (Z))
 #else
 #define XGBOOST_PARALLEL_SORT(X, Y, Z) std::sort((X), (Y), (Z))
 #define XGBOOST_PARALLEL_STABLE_SORT(X, Y, Z) std::stable_sort((X), (Y), (Z))
 #endif
 
-/*! \brief namespace of xgboo st*/
+/*!
+ * \brief Tag function as usable by device
+ */
+#ifdef __NVCC__
+#define XGBOOST_DEVICE __host__ __device__
+#else
+#define XGBOOST_DEVICE
+#endif
+
+/*! \brief namespace of xgboost*/
 namespace xgboost {
 /*!
  * \brief unsigned integer type used in boost,
@@ -76,8 +88,41 @@ struct bst_gpair {
   bst_float grad;
   /*! \brief second order gradient statistics */
   bst_float hess;
-  bst_gpair() {}
-  bst_gpair(bst_float grad, bst_float hess) : grad(grad), hess(hess) {}
+  XGBOOST_DEVICE bst_gpair() : grad(0), hess(0) {}
+
+  XGBOOST_DEVICE bst_gpair(bst_float grad, bst_float hess)
+      : grad(grad), hess(hess) {}
+
+  XGBOOST_DEVICE bst_gpair &operator+=(const bst_gpair &rhs) {
+    grad += rhs.grad;
+    hess += rhs.hess;
+    return *this;
+  }
+
+  XGBOOST_DEVICE bst_gpair operator+(const bst_gpair &rhs) const {
+    bst_gpair g;
+    g.grad = grad + rhs.grad;
+    g.hess = hess + rhs.hess;
+    return g;
+  }
+
+  XGBOOST_DEVICE bst_gpair &operator-=(const bst_gpair &rhs) {
+    grad -= rhs.grad;
+    hess -= rhs.hess;
+    return *this;
+  }
+
+  XGBOOST_DEVICE bst_gpair operator-(const bst_gpair &rhs) const {
+    bst_gpair g;
+    g.grad = grad - rhs.grad;
+    g.hess = hess - rhs.hess;
+    return g;
+  }
+
+  XGBOOST_DEVICE bst_gpair(int value) {
+    *this = bst_gpair(static_cast<float>(value), static_cast<float>(value));
+  }
 };
 
 /*! \brief small eps gap for minimum split decision. */
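Because the constructors and arithmetic operators above are all XGBOOST_DEVICE, a single bst_gpair definition now serves host code and kernels alike, which is what lets the GPU plugin retire its gpu_gpair clone below. A hedged sketch (assumes bst_float is float; the kernel is illustrative, not from the commit):

    __global__ void accumulate(const bst_gpair* in, bst_gpair* out, int n) {
      if (blockIdx.x == 0 && threadIdx.x == 0) {
        bst_gpair total;  // device-side default ctor zero-initializes grad/hess
        for (int i = 0; i < n; ++i) {
          total += in[i];  // device-side operator+=
        }
        *out = total;
      }
    }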
@@ -15,33 +15,29 @@
 
 namespace xgboost {
 namespace tree {
-// When we split on a value which has no left neighbour, define its left
-// neighbour as having left_fvalue = current_fvalue - FVALUE_EPS
-// This produces a split value slightly lower than the current instance
-#define FVALUE_EPS 0.0001
 
 __device__ inline float device_calc_loss_chg(const GPUTrainingParam& param,
-                                             const gpu_gpair& scan,
-                                             const gpu_gpair& missing,
-                                             const gpu_gpair& parent_sum,
+                                             const bst_gpair& scan,
+                                             const bst_gpair& missing,
+                                             const bst_gpair& parent_sum,
                                              const float& parent_gain,
                                              bool missing_left) {
-  gpu_gpair left = scan;
+  bst_gpair left = scan;
 
   if (missing_left) {
     left += missing;
   }
 
-  gpu_gpair right = parent_sum - left;
+  bst_gpair right = parent_sum - left;
 
-  float left_gain = CalcGain(param, left.grad(), left.hess());
-  float right_gain = CalcGain(param, right.grad(), right.hess());
+  float left_gain = CalcGain(param, left.grad, left.hess);
+  float right_gain = CalcGain(param, right.grad, right.hess);
   return left_gain + right_gain - parent_gain;
 }
 
-__device__ float inline loss_chg_missing(const gpu_gpair& scan,
-                                         const gpu_gpair& missing,
-                                         const gpu_gpair& parent_sum,
+__device__ float inline loss_chg_missing(const bst_gpair& scan,
+                                         const bst_gpair& missing,
+                                         const bst_gpair& parent_sum,
                                          const float& parent_gain,
                                          const GPUTrainingParam& param,
                                          bool& missing_left_out) {  // NOLINT
@@ -134,39 +130,39 @@ inline void dense2sparse_tree(RegTree* p_tree,
       tree[nid].set_split(n.split.findex, n.split.fvalue, n.split.missing_left);
       tree.stat(nid).loss_chg = n.split.loss_chg;
       tree.stat(nid).base_weight = n.weight;
-      tree.stat(nid).sum_hess = n.sum_gradients.hess();
+      tree.stat(nid).sum_hess = n.sum_gradients.hess;
      tree[tree[nid].cleft()].set_leaf(0);
       tree[tree[nid].cright()].set_leaf(0);
       nid++;
     } else if (flag == LEAF) {
       tree[nid].set_leaf(n.weight * param.learning_rate);
-      tree.stat(nid).sum_hess = n.sum_gradients.hess();
+      tree.stat(nid).sum_hess = n.sum_gradients.hess;
       nid++;
     }
   }
 }
 
 // Set gradient pair to 0 with p = 1 - subsample
-inline void subsample_gpair(dh::dvec<gpu_gpair>* p_gpair, float subsample,
+inline void subsample_gpair(dh::dvec<bst_gpair>* p_gpair, float subsample,
                             int offset) {
   if (subsample == 1.0) {
     return;
   }
 
-  dh::dvec<gpu_gpair>& gpair = *p_gpair;
+  dh::dvec<bst_gpair>& gpair = *p_gpair;
 
   auto d_gpair = gpair.data();
   dh::BernoulliRng rng(subsample, common::GlobalRandom()());
 
   dh::launch_n(gpair.device_idx(), gpair.size(), [=] __device__(int i) {
     if (!rng(i + offset)) {
-      d_gpair[i] = gpu_gpair();
+      d_gpair[i] = bst_gpair();
     }
   });
 }
 
 // Set gradient pair to 0 with p = 1 - subsample
-inline void subsample_gpair(dh::dvec<gpu_gpair>* p_gpair, float subsample) {
+inline void subsample_gpair(dh::dvec<bst_gpair>* p_gpair, float subsample) {
   int offset = 0;
   subsample_gpair(p_gpair, subsample, offset);
 }
@@ -182,11 +178,11 @@ inline std::vector<int> col_sample(std::vector<int> features, float colsample) {
 }
 struct GpairCallbackOp {
   // Running prefix
-  gpu_gpair running_total;
+  bst_gpair running_total;
   // Constructor
-  __device__ GpairCallbackOp() : running_total(gpu_gpair()) {}
-  __device__ gpu_gpair operator()(gpu_gpair block_aggregate) {
-    gpu_gpair old_prefix = running_total;
+  __device__ GpairCallbackOp() : running_total(bst_gpair()) {}
+  __device__ bst_gpair operator()(bst_gpair block_aggregate) {
+    bst_gpair old_prefix = running_total;
     running_total += block_aggregate;
     return old_prefix;
   }
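GpairCallbackOp is a cub "block prefix callback": on each tile of a multi-tile block scan, cub hands it the tile aggregate and it returns the running total so far, which becomes the exclusive prefix of the next tile. A self-contained float analogue of the idiom (a sketch; the real code scans bst_gpair this way in find_split_kernel further down):

    #include <cub/cub.cuh>

    struct FloatCallbackOp {
      float running_total;
      __device__ FloatCallbackOp() : running_total(0.f) {}
      __device__ float operator()(float block_aggregate) {
        float old_prefix = running_total;
        running_total += block_aggregate;
        return old_prefix;  // exclusive prefix for the next tile
      }
    };

    template <int BLOCK_THREADS>
    __global__ void tiled_exclusive_scan(const float* in, float* out, int n) {
      typedef cub::BlockScan<float, BLOCK_THREADS> BlockScanT;
      __shared__ typename BlockScanT::TempStorage temp_storage;
      FloatCallbackOp prefix_op;
      for (int base = 0; base < n; base += BLOCK_THREADS) {
        int i = base + threadIdx.x;
        float v = (i < n) ? in[i] : 0.f;
        BlockScanT(temp_storage).ExclusiveScan(v, v, cub::Sum(), prefix_op);
        if (i < n) out[i] = v;
        __syncthreads();  // temp_storage is reused by the next tile
      }
    }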
@@ -17,8 +17,8 @@
 
 #include "../../../../src/tree/param.h"
 #include "../common.cuh"
-#include "loss_functions.cuh"
 #include "node.cuh"
+#include "../types.cuh"
 
 namespace xgboost {
 namespace tree {
@@ -66,10 +66,10 @@ DEV_INLINE void atomicArgMax(Split* address, Split val) {
 
 template <typename node_id_t>
 DEV_INLINE void argMaxWithAtomics(
-    int id, Split* nodeSplits, const gpu_gpair* gradScans,
-    const gpu_gpair* gradSums, const float* vals, const int* colIds,
+    int id, Split* nodeSplits, const bst_gpair* gradScans,
+    const bst_gpair* gradSums, const float* vals, const int* colIds,
     const node_id_t* nodeAssigns, const Node<node_id_t>* nodes, int nUniqKeys,
-    node_id_t nodeStart, int len, const TrainParam& param) {
+    node_id_t nodeStart, int len, const GPUTrainingParam& param) {
   int nodeId = nodeAssigns[id];
   ///@todo: this is really a bad check! but will be fixed when we move
   /// to key-based reduction
@@ -78,14 +78,14 @@ DEV_INLINE void argMaxWithAtomics(
       (vals[id] == vals[id - 1]))) {
     if (nodeId != UNUSED_NODE) {
       int sumId = abs2uniqKey(id, nodeAssigns, colIds, nodeStart, nUniqKeys);
-      gpu_gpair colSum = gradSums[sumId];
+      bst_gpair colSum = gradSums[sumId];
       int uid = nodeId - nodeStart;
       Node<node_id_t> n = nodes[nodeId];
-      gpu_gpair parentSum = n.gradSum;
+      bst_gpair parentSum = n.gradSum;
       float parentGain = n.score;
       bool tmp;
       Split s;
-      gpu_gpair missing = parentSum - colSum;
+      bst_gpair missing = parentSum - colSum;
       s.score = loss_chg_missing(gradScans[id], missing, parentSum, parentGain,
                                  param, tmp);
       s.index = id;
@@ -96,7 +96,7 @@ DEV_INLINE void argMaxWithAtomics(
 
 template <typename node_id_t>
 __global__ void atomicArgMaxByKeyGmem(
-    Split* nodeSplits, const gpu_gpair* gradScans, const gpu_gpair* gradSums,
+    Split* nodeSplits, const bst_gpair* gradScans, const bst_gpair* gradSums,
     const float* vals, const int* colIds, const node_id_t* nodeAssigns,
     const Node<node_id_t>* nodes, int nUniqKeys, node_id_t nodeStart, int len,
     const TrainParam param) {
@@ -104,13 +104,13 @@ __global__ void atomicArgMaxByKeyGmem(
   const int stride = blockDim.x * gridDim.x;
   for (; id < len; id += stride) {
     argMaxWithAtomics(id, nodeSplits, gradScans, gradSums, vals, colIds,
-                      nodeAssigns, nodes, nUniqKeys, nodeStart, len, param);
+                      nodeAssigns, nodes, nUniqKeys, nodeStart, len, GPUTrainingParam(param));
   }
 }
 
 template <typename node_id_t>
 __global__ void atomicArgMaxByKeySmem(
-    Split* nodeSplits, const gpu_gpair* gradScans, const gpu_gpair* gradSums,
+    Split* nodeSplits, const bst_gpair* gradScans, const bst_gpair* gradSums,
     const float* vals, const int* colIds, const node_id_t* nodeAssigns,
     const Node<node_id_t>* nodes, int nUniqKeys, node_id_t nodeStart, int len,
     const TrainParam param) {
@@ -153,8 +153,8 @@ __global__ void atomicArgMaxByKeySmem(
  * @param algo which algorithm to use for argmax_by_key
  */
 template <typename node_id_t, int BLKDIM = 256, int ITEMS_PER_THREAD = 4>
-void argMaxByKey(Split* nodeSplits, const gpu_gpair* gradScans,
-                 const gpu_gpair* gradSums, const float* vals,
+void argMaxByKey(Split* nodeSplits, const bst_gpair* gradScans,
+                 const bst_gpair* gradSums, const float* vals,
                  const int* colIds, const node_id_t* nodeAssigns,
                  const Node<node_id_t>* nodes, int nUniqKeys,
                  node_id_t nodeStart, int len, const TrainParam param,
@@ -16,7 +16,6 @@
 #pragma once
 
 #include "../common.cuh"
-#include "gradients.cuh"
 
 namespace xgboost {
 namespace tree {
@@ -24,11 +23,11 @@ namespace exact {
 
 /**
  * @struct Pair fused_scan_reduce_by_key.cuh
- * @brief Pair used for key basd scan operations on gpu_gpair
+ * @brief Pair used for key basd scan operations on bst_gpair
  */
 struct Pair {
   int key;
-  gpu_gpair value;
+  bst_gpair value;
 };
 
 /** define a key that's not used at all in the entire boosting process */
@@ -61,15 +60,27 @@ struct AddByKey {
   }
 };
 
+/**
+ * @brief Gradient value getter function
+ * @param id the index into the vals or instIds array to which to fetch
+ * @param vals the gradient value buffer
+ * @param instIds instance index buffer
+ * @return the expected gradient value
+ */
+HOST_DEV_INLINE bst_gpair get(int id, const bst_gpair* vals, const int* instIds) {
+  id = instIds[id];
+  return vals[id];
+}
+
 template <typename node_id_t, int BLKDIM_L1L3>
-__global__ void cubScanByKeyL1(gpu_gpair* scans, const gpu_gpair* vals,
-                               const int* instIds, gpu_gpair* mScans,
+__global__ void cubScanByKeyL1(bst_gpair* scans, const bst_gpair* vals,
+                               const int* instIds, bst_gpair* mScans,
                                int* mKeys, const node_id_t* keys, int nUniqKeys,
                                const int* colIds, node_id_t nodeStart,
                                const int size) {
-  Pair rootPair = {NONE_KEY, gpu_gpair(0.f, 0.f)};
+  Pair rootPair = {NONE_KEY, bst_gpair(0.f, 0.f)};
   int myKey;
-  gpu_gpair myValue;
+  bst_gpair myValue;
   typedef cub::BlockScan<Pair, BLKDIM_L1L3> BlockScan;
   __shared__ typename BlockScan::TempStorage temp_storage;
   Pair threadData;
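The relocated get() helper is just a gather: positions along a sorted column are mapped back to instances through instIds before the gradient is read. A tiny host-side illustration with made-up data (get is HOST_DEV_INLINE, so it runs on either side):

    std::vector<bst_gpair> grads = {{0.1f, 1.f}, {0.2f, 1.f}, {0.3f, 1.f}};
    std::vector<int> instIds = {2, 0, 1};  // column position -> instance id

    // Position 1 of the column belongs to instance 0, so this yields grads[0].
    bst_gpair g = get(1, grads.data(), instIds.data());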
@@ -98,14 +109,14 @@ __global__ void cubScanByKeyL1(gpu_gpair* scans, const gpu_gpair* vals,
   }
   if (threadIdx.x == BLKDIM_L1L3 - 1) {
     threadData.value =
-        (myKey == previousKey) ? threadData.value : gpu_gpair(0.0f, 0.0f);
+        (myKey == previousKey) ? threadData.value : bst_gpair(0.0f, 0.0f);
     mKeys[blockIdx.x] = myKey;
     mScans[blockIdx.x] = threadData.value + myValue;
   }
 }
 
 template <int BLKSIZE>
-__global__ void cubScanByKeyL2(gpu_gpair* mScans, int* mKeys, int mLength) {
+__global__ void cubScanByKeyL2(bst_gpair* mScans, int* mKeys, int mLength) {
   typedef cub::BlockScan<Pair, BLKSIZE, cub::BLOCK_SCAN_WARP_SCANS> BlockScan;
   Pair threadData;
   __shared__ typename BlockScan::TempStorage temp_storage;
@@ -119,9 +130,9 @@ __global__ void cubScanByKeyL2(gpu_gpair* mScans, int* mKeys, int mLength) {
 }
 
 template <typename node_id_t, int BLKDIM_L1L3>
-__global__ void cubScanByKeyL3(gpu_gpair* sums, gpu_gpair* scans,
-                               const gpu_gpair* vals, const int* instIds,
-                               const gpu_gpair* mScans, const int* mKeys,
+__global__ void cubScanByKeyL3(bst_gpair* sums, bst_gpair* scans,
+                               const bst_gpair* vals, const int* instIds,
+                               const bst_gpair* mScans, const int* mKeys,
                                const node_id_t* keys, int nUniqKeys,
                                const int* colIds, node_id_t nodeStart,
                                const int size) {
@@ -130,19 +141,19 @@ __global__ void cubScanByKeyL3(gpu_gpair* sums, gpu_gpair* scans,
   // to avoid the following warning from nvcc:
   // __shared__ memory variable with non-empty constructor or destructor
   // (potential race between threads)
-  __shared__ char gradBuff[sizeof(gpu_gpair)];
+  __shared__ char gradBuff[sizeof(bst_gpair)];
   __shared__ int s_mKeys;
-  gpu_gpair* s_mScans = (gpu_gpair*)gradBuff;
+  bst_gpair* s_mScans = (bst_gpair*)gradBuff;
   if (tid >= size) return;
   // cache block-wide partial scan info
   if (relId == 0) {
     s_mKeys = (blockIdx.x > 0) ? mKeys[blockIdx.x - 1] : NONE_KEY;
-    s_mScans[0] = (blockIdx.x > 0) ? mScans[blockIdx.x - 1] : gpu_gpair();
+    s_mScans[0] = (blockIdx.x > 0) ? mScans[blockIdx.x - 1] : bst_gpair();
   }
   int myKey = abs2uniqKey(tid, keys, colIds, nodeStart, nUniqKeys);
   int previousKey = tid == 0 ? NONE_KEY : abs2uniqKey(tid - 1, keys, colIds,
                                                       nodeStart, nUniqKeys);
-  gpu_gpair myValue = scans[tid];
+  bst_gpair myValue = scans[tid];
   __syncthreads();
   if (blockIdx.x > 0 && s_mKeys == previousKey) {
     myValue += s_mScans[0];
@@ -152,7 +163,7 @@ __global__ void cubScanByKeyL3(gpu_gpair* sums, gpu_gpair* scans,
   }
   if ((previousKey != myKey) && (previousKey >= 0)) {
     sums[previousKey] = myValue;
-    myValue = gpu_gpair(0.0f, 0.0f);
+    myValue = bst_gpair(0.0f, 0.0f);
   }
   scans[tid] = myValue;
 }
@@ -178,12 +189,12 @@ __global__ void cubScanByKeyL3(gpu_gpair* sums, gpu_gpair* scans,
  * @param nodeStart index of the leftmost node in the current level
  */
 template <typename node_id_t, int BLKDIM_L1L3 = 256, int BLKDIM_L2 = 512>
-void reduceScanByKey(gpu_gpair* sums, gpu_gpair* scans, const gpu_gpair* vals,
+void reduceScanByKey(bst_gpair* sums, bst_gpair* scans, const bst_gpair* vals,
                      const int* instIds, const node_id_t* keys, int size,
-                     int nUniqKeys, int nCols, gpu_gpair* tmpScans,
+                     int nUniqKeys, int nCols, bst_gpair* tmpScans,
                      int* tmpKeys, const int* colIds, node_id_t nodeStart) {
   int nBlks = dh::div_round_up(size, BLKDIM_L1L3);
-  cudaMemset(sums, 0, nUniqKeys * nCols * sizeof(gpu_gpair));
+  cudaMemset(sums, 0, nUniqKeys * nCols * sizeof(bst_gpair));
   cubScanByKeyL1<node_id_t, BLKDIM_L1L3><<<nBlks, BLKDIM_L1L3>>>(
       scans, vals, instIds, tmpScans, tmpKeys, keys, nUniqKeys, colIds,
       nodeStart, size);
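reduceScanByKey's body is cut off above after the L1 launch, but the shape of the three-level pipeline follows from the kernel signatures: L1 emits per-block partial scans plus one carry (mScans/mKeys) per block, a single L2 block scans the carries, and L3 folds the carries back in and writes the per-key sums. The remaining launches below are reconstructed from those signatures — treat the exact argument lists as an assumption, not verbatim source:

    cubScanByKeyL2<BLKDIM_L2><<<1, BLKDIM_L2>>>(tmpScans, tmpKeys, nBlks);
    cubScanByKeyL3<node_id_t, BLKDIM_L1L3><<<nBlks, BLKDIM_L1L3>>>(
        sums, scans, vals, instIds, tmpScans, tmpKeys, keys, nUniqKeys,
        colIds, nodeStart, size);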
@@ -19,13 +19,11 @@
 #include <vector>
 #include "../../../../src/tree/param.h"
 #include "../common.cuh"
-#include "argmax_by_key.cuh"
-#include "cub/cub.cuh"
-#include "fused_scan_reduce_by_key.cuh"
-#include "gradients.cuh"
-#include "loss_functions.cuh"
+#include <vector>
 #include "node.cuh"
 #include "split2node.cuh"
+#include "argmax_by_key.cuh"
+#include "fused_scan_reduce_by_key.cuh"
 #include "xgboost/tree_updater.h"
 
 namespace xgboost {
@@ -33,13 +31,13 @@ namespace tree {
 namespace exact {
 
 template <typename node_id_t>
-__global__ void initRootNode(Node<node_id_t>* nodes, const gpu_gpair* sums,
+__global__ void initRootNode(Node<node_id_t>* nodes, const bst_gpair* sums,
                              const TrainParam param) {
   // gradients already evaluated inside transferGrads
   Node<node_id_t> n;
   n.gradSum = sums[0];
-  n.score = CalcGain(param, n.gradSum.g, n.gradSum.h);
-  n.weight = CalcWeight(param, n.gradSum.g, n.gradSum.h);
+  n.score = CalcGain(param, n.gradSum.grad, n.gradSum.hess);
+  n.weight = CalcWeight(param, n.gradSum.grad, n.gradSum.hess);
   n.id = 0;
   nodes[0] = n;
 }
@@ -198,13 +196,13 @@ class GPUBuilder {
   dh::dvec<int> instIds_cached;
   /** column offsets for these feature values */
   dh::dvec<int> colOffsets;
-  dh::dvec<gpu_gpair> gradsInst;
+  dh::dvec<bst_gpair> gradsInst;
   dh::dvec2<node_id_t> nodeAssigns;
   dh::dvec2<int> nodeLocations;
   dh::dvec<Node<node_id_t>> nodes;
   dh::dvec<node_id_t> nodeAssignsPerInst;
-  dh::dvec<gpu_gpair> gradSums;
-  dh::dvec<gpu_gpair> gradScans;
+  dh::dvec<bst_gpair> gradSums;
+  dh::dvec<bst_gpair> gradScans;
   dh::dvec<Split> nodeSplits;
   int nVals;
   int nRows;
@@ -212,7 +210,7 @@ class GPUBuilder {
   int maxNodes;
   int maxLeaves;
   dh::CubMemory tmp_mem;
-  dh::dvec<gpu_gpair> tmpScanGradBuff;
+  dh::dvec<bst_gpair> tmpScanGradBuff;
   dh::dvec<int> tmpScanKeyBuff;
   dh::dvec<int> colIds;
   dh::bulk_allocator<dh::memory_type::DEVICE> ba;
@@ -310,10 +308,10 @@ class GPUBuilder {
   void transferGrads(const std::vector<bst_gpair>& gpair) {
     // HACK
     dh::safe_cuda(cudaMemcpy(gradsInst.data(), &(gpair[0]),
-                             sizeof(gpu_gpair) * nRows,
+                             sizeof(bst_gpair) * nRows,
                              cudaMemcpyHostToDevice));
     // evaluate the full-grad reduction for the root node
-    sumReduction<gpu_gpair>(tmp_mem, gradsInst, gradSums, nRows);
+    sumReduction<bst_gpair>(tmp_mem, gradsInst, gradSums, nRows);
   }
 
   void initNodeData(int level, node_id_t nodeStart, int nNodes) {
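transferGrads is where the unification pays off most visibly: the host-side std::vector<bst_gpair> and the device buffer now share one type, so the raw cudaMemcpy is layout-safe by construction rather than by the old implicit assumption that gpu_gpair mirrored bst_gpair (hence the // HACK comment). The same idiom stripped of the dh:: wrappers looks like this (a sketch; error handling omitted):

    std::vector<bst_gpair> gpair(nRows, bst_gpair(0.f, 0.f));
    bst_gpair* d_gpair = nullptr;
    cudaMalloc(&d_gpair, sizeof(bst_gpair) * nRows);
    cudaMemcpy(d_gpair, gpair.data(), sizeof(bst_gpair) * nRows,
               cudaMemcpyHostToDevice);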
@@ -371,13 +369,13 @@ class GPUBuilder {
       const Node<node_id_t>& n = hNodes[i];
       if ((i != 0) && hNodes[i].isLeaf()) {
         tree[nodeId].set_leaf(n.weight * param.learning_rate);
-        tree.stat(nodeId).sum_hess = n.gradSum.h;
+        tree.stat(nodeId).sum_hess = n.gradSum.hess;
         ++nodeId;
       } else if (!hNodes[i].isUnused()) {
         tree.AddChilds(nodeId);
         tree[nodeId].set_split(n.colIdx, n.threshold, n.dir == LeftDir);
         tree.stat(nodeId).loss_chg = n.score;
-        tree.stat(nodeId).sum_hess = n.gradSum.h;
+        tree.stat(nodeId).sum_hess = n.gradSum.hess;
         tree.stat(nodeId).base_weight = n.weight;
         tree[tree[nodeId].cleft()].set_leaf(0);
         tree[tree[nodeId].cright()].set_leaf(0);
@@ -1,91 +0,0 @@
-/*
- * Copyright (c) 2017, NVIDIA CORPORATION, Xgboost contributors. All rights
- * reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#pragma once
-
-#include "../common.cuh"
-
-namespace xgboost {
-namespace tree {
-namespace exact {
-
-/**
- * @struct gpu_gpair gradients.cuh
- * @brief The first/second order gradients for iteratively building the tree
- */
-struct gpu_gpair {
-  /** the 'g_i' as it appears in the xgboost paper */
-  float g;
-  /** the 'h_i' as it appears in the xgboost paper */
-  float h;
-
-  HOST_DEV_INLINE gpu_gpair() : g(0.f), h(0.f) {}
-  HOST_DEV_INLINE gpu_gpair(const float& _g, const float& _h) : g(_g), h(_h) {}
-  HOST_DEV_INLINE gpu_gpair(const gpu_gpair& a) : g(a.g), h(a.h) {}
-
-  /**
-   * @brief Checks whether the hessian is more than the defined weight
-   * @param minWeight minimum weight to be compared against
-   * @return true if the hessian is greater than the minWeight
-   * @note this is useful in deciding whether to further split to child node
-   */
-  HOST_DEV_INLINE bool isSplittable(float minWeight) const {
-    return (h > minWeight);
-  }
-
-  HOST_DEV_INLINE gpu_gpair& operator+=(const gpu_gpair& a) {
-    g += a.g;
-    h += a.h;
-    return *this;
-  }
-
-  HOST_DEV_INLINE gpu_gpair& operator-=(const gpu_gpair& a) {
-    g -= a.g;
-    h -= a.h;
-    return *this;
-  }
-
-  HOST_DEV_INLINE friend gpu_gpair operator+(const gpu_gpair& a,
-                                             const gpu_gpair& b) {
-    return gpu_gpair(a.g + b.g, a.h + b.h);
-  }
-
-  HOST_DEV_INLINE friend gpu_gpair operator-(const gpu_gpair& a,
-                                             const gpu_gpair& b) {
-    return gpu_gpair(a.g - b.g, a.h - b.h);
-  }
-
-  HOST_DEV_INLINE gpu_gpair(int value) {
-    *this = gpu_gpair((float)value, (float)value);
-  }
-};
-
-/**
- * @brief Gradient value getter function
- * @param id the index into the vals or instIds array to which to fetch
- * @param vals the gradient value buffer
- * @param instIds instance index buffer
- * @return the expected gradient value
- */
-HOST_DEV_INLINE gpu_gpair get(int id, const gpu_gpair* vals,
-                              const int* instIds) {
-  id = instIds[id];
-  return vals[id];
-}
-
-}  // namespace exact
-}  // namespace tree
-}  // namespace xgboost
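Nearly every member of the deleted gpu_gpair has a direct bst_gpair counterpart, so call sites only needed g/h renamed to grad/hess. The one helper without a new home is isSplittable(); if a call site still wanted it, a free-function equivalent would look like this (hypothetical, not part of this commit):

    XGBOOST_DEVICE inline bool is_splittable(const bst_gpair& sum,
                                             float min_weight) {
      return sum.hess > min_weight;  // enough hessian mass to split further
    }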
@@ -1,60 +0,0 @@
-/*
- * Copyright (c) 2017, NVIDIA CORPORATION, Xgboost contributors. All rights
- * reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-#pragma once
-
-#include "../common.cuh"
-#include "gradients.cuh"
-
-namespace xgboost {
-namespace tree {
-namespace exact {
-
-HOST_DEV_INLINE float device_calc_loss_chg(
-    const TrainParam &param, const gpu_gpair &scan, const gpu_gpair &missing,
-    const gpu_gpair &parent_sum, const float &parent_gain, bool missing_left) {
-  gpu_gpair left = scan;
-  if (missing_left) {
-    left += missing;
-  }
-  gpu_gpair right = parent_sum - left;
-  float left_gain = CalcGain(param, left.g, left.h);
-  float right_gain = CalcGain(param, right.g, right.h);
-  return left_gain + right_gain - parent_gain;
-}
-
-HOST_DEV_INLINE float loss_chg_missing(const gpu_gpair &scan,
-                                       const gpu_gpair &missing,
-                                       const gpu_gpair &parent_sum,
-                                       const float &parent_gain,
-                                       const TrainParam &param,
-                                       bool &missing_left_out) {
-  float missing_left_loss =
-      device_calc_loss_chg(param, scan, missing, parent_sum, parent_gain, true);
-  float missing_right_loss = device_calc_loss_chg(
-      param, scan, missing, parent_sum, parent_gain, false);
-  if (missing_left_loss >= missing_right_loss) {
-    missing_left_out = true;
-    return missing_left_loss;
-  } else {
-    missing_left_out = false;
-    return missing_right_loss;
-  }
-}
-
-}  // namespace exact
-}  // namespace tree
-}  // namespace xgboost
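This deleted header duplicated the split-gain logic that survives in the .cu version earlier in the diff: evaluate loss_chg = left_gain + right_gain - parent_gain twice, once with the missing-value bucket routed left and once right, and keep the better direction. As a concrete arithmetic check, assuming the core CalcGain term g*g/(h + lambda) with lambda = 1 (the real TrainParam adds further regularization terms):

    float lambda = 1.f;
    auto gain = [&](float g, float h) { return g * g / (h + lambda); };
    float parent_gain = gain(-2.f, 4.f);   // 0.8
    float loss_chg = gain(-2.f, 2.f)       // left  ~1.333
                   + gain(0.f, 2.f)        // right  0.0
                   - parent_gain;          // ~0.533 > 0: the split helps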
@@ -17,7 +17,6 @@
 #pragma once
 
 #include "../common.cuh"
-#include "gradients.cuh"
 
 namespace xgboost {
 namespace tree {
@@ -67,7 +66,7 @@ template <typename node_id_t>
 class Node {
  public:
   /** sum of gradients across all training samples part of this node */
-  gpu_gpair gradSum;
+  bst_gpair gradSum;
   /** the optimal score for this node */
   float score;
   /** weightage for this node */
@@ -16,8 +16,6 @@
 #pragma once
 
 #include "../../../../src/tree/param.h"
-#include "gradients.cuh"
-#include "loss_functions.cuh"
 #include "node.cuh"
 
 namespace xgboost {
@@ -37,11 +35,11 @@ namespace exact {
  */
 template <typename node_id_t>
 DEV_INLINE void updateOneChildNode(Node<node_id_t>* nodes, int nid,
-                                   const gpu_gpair& grad,
+                                   const bst_gpair& grad,
                                    const TrainParam& param) {
   nodes[nid].gradSum = grad;
-  nodes[nid].score = CalcGain(param, grad.g, grad.h);
-  nodes[nid].weight = CalcWeight(param, grad.g, grad.h);
+  nodes[nid].score = CalcGain(param, grad.grad, grad.hess);
+  nodes[nid].weight = CalcWeight(param, grad.grad, grad.hess);
   nodes[nid].id = nid;
 }
 
@@ -56,7 +54,7 @@ DEV_INLINE void updateOneChildNode(Node<node_id_t>* nodes, int nid,
  */
 template <typename node_id_t>
 DEV_INLINE void updateChildNodes(Node<node_id_t>* nodes, int pid,
-                                 const gpu_gpair& gradL, const gpu_gpair& gradR,
+                                 const bst_gpair& gradL, const bst_gpair& gradR,
                                  const TrainParam& param) {
   int childId = (pid * 2) + 1;
   updateOneChildNode(nodes, childId, gradL, param);
@@ -66,15 +64,15 @@ DEV_INLINE void updateChildNodes(Node<node_id_t>* nodes, int pid,
 template <typename node_id_t>
 DEV_INLINE void updateNodeAndChildren(Node<node_id_t>* nodes, const Split& s,
                                       const Node<node_id_t>& n, int absNodeId,
-                                      int colId, const gpu_gpair& gradScan,
-                                      const gpu_gpair& colSum, float thresh,
+                                      int colId, const bst_gpair& gradScan,
+                                      const bst_gpair& colSum, float thresh,
                                       const TrainParam& param) {
   bool missingLeft = true;
   // get the default direction for the current node
-  gpu_gpair missing = n.gradSum - colSum;
+  bst_gpair missing = n.gradSum - colSum;
   loss_chg_missing(gradScan, missing, n.gradSum, n.score, param, missingLeft);
   // get the score/weight/id/gradSum for left and right child nodes
-  gpu_gpair lGradSum, rGradSum;
+  bst_gpair lGradSum, rGradSum;
   if (missingLeft) {
     lGradSum = gradScan + n.gradSum - colSum;
   } else {
@@ -90,8 +88,8 @@ DEV_INLINE void updateNodeAndChildren(Node<node_id_t>* nodes, const Split& s,
 
 template <typename node_id_t, int BLKDIM = 256>
 __global__ void split2nodeKernel(
-    Node<node_id_t>* nodes, const Split* nodeSplits, const gpu_gpair* gradScans,
-    const gpu_gpair* gradSums, const float* vals, const int* colIds,
+    Node<node_id_t>* nodes, const Split* nodeSplits, const bst_gpair* gradScans,
+    const bst_gpair* gradSums, const float* vals, const int* colIds,
     const int* colOffsets, const node_id_t* nodeAssigns, int nUniqKeys,
     node_id_t nodeStart, int nCols, const TrainParam param) {
   int uid = (blockIdx.x * blockDim.x) + threadIdx.x;
@@ -132,7 +130,7 @@ __global__ void split2nodeKernel(
  */
 template <typename node_id_t, int BLKDIM = 256>
 void split2node(Node<node_id_t>* nodes, const Split* nodeSplits,
-                const gpu_gpair* gradScans, const gpu_gpair* gradSums,
+                const bst_gpair* gradScans, const bst_gpair* gradSums,
                 const float* vals, const int* colIds, const int* colOffsets,
                 const node_id_t* nodeAssigns, int nUniqKeys,
                 node_id_t nodeStart, int nCols, const TrainParam param) {
@@ -1,11 +0,0 @@
-/*!
- * Copyright 2016 Rory mitchell
- */
-#pragma once
-#include "../../../src/common/random.h"
-#include "../../../src/tree/param.h"
-#include "types.cuh"
-
-namespace xgboost {
-namespace tree {}  // namespace tree
-}  // namespace xgboost
@@ -40,10 +40,10 @@ void DeviceHist::Init(int n_bins_in) {
 
 void DeviceHist::Reset(int device_idx) {
   cudaSetDevice(device_idx);
-  data.fill(gpu_gpair());
+  data.fill(bst_gpair());
 }
 
-gpu_gpair* DeviceHist::GetLevelPtr(int depth) {
+bst_gpair* DeviceHist::GetLevelPtr(int depth) {
   return data.data() + n_nodes(depth - 1) * n_bins;
 }
 
@@ -53,20 +53,20 @@ HistBuilder DeviceHist::GetBuilder() {
   return HistBuilder(data.data(), n_bins);
 }
 
-HistBuilder::HistBuilder(gpu_gpair* ptr, int n_bins)
+HistBuilder::HistBuilder(bst_gpair* ptr, int n_bins)
     : d_hist(ptr), n_bins(n_bins) {}
 
-__device__ void HistBuilder::Add(gpu_gpair gpair, int gidx, int nidx) const {
+__device__ void HistBuilder::Add(bst_gpair gpair, int gidx, int nidx) const {
   int hist_idx = nidx * n_bins + gidx;
-  atomicAdd(&(d_hist[hist_idx]._grad), gpair._grad);  // OPTMARK: This and below
+  atomicAdd(&(d_hist[hist_idx].grad), gpair.grad);  // OPTMARK: This and below
                                                     // line lead to about 3X
                                                     // slowdown due to memory
                                                     // dependency and access
                                                     // pattern issues.
-  atomicAdd(&(d_hist[hist_idx]._hess), gpair._hess);
+  atomicAdd(&(d_hist[hist_idx].hess), gpair.hess);
 }
 
-__device__ gpu_gpair HistBuilder::Get(int gidx, int nidx) const {
+__device__ bst_gpair HistBuilder::Get(int gidx, int nidx) const {
   return d_hist[nidx * n_bins + gidx];
 }
 
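HistBuilder::Add now accumulates straight into the public grad/hess members instead of the old _grad/_hess fields. The pattern — two independent float atomicAdds per gradient pair — is what the OPTMARK comments flag as a roughly 3x hotspot; standalone it looks like this (a sketch assuming bst_float is float):

    __global__ void build_hist(const bst_gpair* gpairs, const int* bin_of,
                               bst_gpair* hist, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i >= n) return;
      // Two separate 32-bit atomics: the pair is not updated as one atomic
      // unit, which is fine for a pure sum.
      atomicAdd(&hist[bin_of[i]].grad, gpairs[i].grad);
      atomicAdd(&hist[bin_of[i]].hess, gpairs[i].hess);
    }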
@@ -362,7 +362,7 @@ void GPUHistBuilder::BuildHist(int depth) {
     if (!is_smallest && depth > 0) return;
 
     int gidx = d_gidx[local_idx];
-    gpu_gpair gpair = d_gpair[ridx - row_begin];
+    bst_gpair gpair = d_gpair[ridx - row_begin];
 
     hist_builder.Add(gpair, gidx, nidx);  // OPTMARK: This is slow, could use
                                           // shared memory or cache results
@@ -382,14 +382,14 @@ void GPUHistBuilder::BuildHist(int depth) {
   // TODO(JCM): use out of place with pre-allocated buffer, but then have to
   // copy
   // back on device
-  // fprintf(stderr,"sizeof(gpu_gpair)/sizeof(float)=%d\n",sizeof(gpu_gpair)/sizeof(float));
+  // fprintf(stderr,"sizeof(bst_gpair)/sizeof(float)=%d\n",sizeof(bst_gpair)/sizeof(float));
   for (int d_idx = 0; d_idx < n_devices; d_idx++) {
     int device_idx = dList[d_idx];
     dh::safe_cuda(cudaSetDevice(device_idx));
     dh::safe_nccl(ncclAllReduce(
         reinterpret_cast<const void*>(hist_vec[d_idx].GetLevelPtr(depth)),
         reinterpret_cast<void*>(hist_vec[d_idx].GetLevelPtr(depth)),
-        hist_vec[d_idx].LevelSize(depth) * sizeof(gpu_gpair) / sizeof(float),
+        hist_vec[d_idx].LevelSize(depth) * sizeof(bst_gpair) / sizeof(float),
         ncclFloat, ncclSum, comms[d_idx], *(streams[d_idx])));
   }
 
@@ -423,9 +423,9 @@ void GPUHistBuilder::BuildHist(int depth) {
     }
 
     int gidx = idx % hist_builder.n_bins;
-    gpu_gpair parent = hist_builder.Get(gidx, parent_nidx(nidx));
+    bst_gpair parent = hist_builder.Get(gidx, parent_nidx(nidx));
     int other_nidx = left_smallest ? nidx - 1 : nidx + 1;
-    gpu_gpair other = hist_builder.Get(gidx, other_nidx);
+    bst_gpair other = hist_builder.Get(gidx, other_nidx);
     hist_builder.Add(parent - other, gidx,
                      nidx);  // OPTMARK: This is slow, could use shared
                              // memory or cache results intead of writing to
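The hunk above is the classic histogram subtraction trick: only the smaller child is histogrammed from data, and each bin of the larger child is recovered as parent minus sibling, roughly halving histogram work per level. For one bin, the arithmetic is simply:

    bst_gpair parent(1.0f, 8.0f);         // parent's bin
    bst_gpair smaller(0.4f, 3.0f);        // built from the smaller child's rows
    bst_gpair larger = parent - smaller;  // (0.6, 5.0), no rows touched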
@@ -438,16 +438,16 @@ void GPUHistBuilder::BuildHist(int depth) {
 
 template <int BLOCK_THREADS>
 __global__ void find_split_kernel(
-    const gpu_gpair* d_level_hist, int* d_feature_segments, int depth,
+    const bst_gpair* d_level_hist, int* d_feature_segments, int depth,
     int n_features, int n_bins, Node* d_nodes, Node* d_nodes_temp,
     Node* d_nodes_child_temp, int nodes_offset_device, float* d_fidx_min_map,
     float* d_gidx_fvalue_map, GPUTrainingParam gpu_param,
     bool* d_left_child_smallest_temp, bool colsample, int* d_feature_flags) {
   typedef cub::KeyValuePair<int, float> ArgMaxT;
-  typedef cub::BlockScan<gpu_gpair, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>
+  typedef cub::BlockScan<bst_gpair, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>
       BlockScanT;
   typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT;
-  typedef cub::BlockReduce<gpu_gpair, BLOCK_THREADS> SumReduceT;
+  typedef cub::BlockReduce<bst_gpair, BLOCK_THREADS> SumReduceT;
 
   union TempStorage {
     typename BlockScanT::TempStorage scan;
@@ -456,12 +456,12 @@ __global__ void find_split_kernel(
   };
 
   struct UninitializedSplit : cub::Uninitialized<Split> {};
-  struct UninitializedGpair : cub::Uninitialized<gpu_gpair> {};
+  struct UninitializedGpair : cub::Uninitialized<bst_gpair> {};
 
   __shared__ UninitializedSplit uninitialized_split;
   Split& split = uninitialized_split.Alias();
   __shared__ UninitializedGpair uninitialized_sum;
-  gpu_gpair& shared_sum = uninitialized_sum.Alias();
+  bst_gpair& shared_sum = uninitialized_sum.Alias();
   __shared__ ArgMaxT block_max;
   __shared__ TempStorage temp_storage;
 
@@ -484,12 +484,12 @@ __global__ void find_split_kernel(
     int gidx = (begin - (level_node_idx * n_bins)) + threadIdx.x;
     bool thread_active = threadIdx.x < end - begin;
 
-    gpu_gpair feature_sum = gpu_gpair();
+    bst_gpair feature_sum = bst_gpair();
     for (int reduce_begin = begin; reduce_begin < end;
          reduce_begin += BLOCK_THREADS) {
       // Scan histogram
-      gpu_gpair bin = thread_active ? d_level_hist[reduce_begin + threadIdx.x]
-                                    : gpu_gpair();
+      bst_gpair bin = thread_active ? d_level_hist[reduce_begin + threadIdx.x]
+                                    : bst_gpair();
 
       feature_sum +=
          SumReduceT(temp_storage.sum_reduce).Reduce(bin, cub::Sum());
@@ -503,17 +503,17 @@ __global__ void find_split_kernel(
     GpairCallbackOp prefix_op = GpairCallbackOp();
     for (int scan_begin = begin; scan_begin < end;
          scan_begin += BLOCK_THREADS) {
-      gpu_gpair bin =
-          thread_active ? d_level_hist[scan_begin + threadIdx.x] : gpu_gpair();
+      bst_gpair bin =
+          thread_active ? d_level_hist[scan_begin + threadIdx.x] : bst_gpair();
 
       BlockScanT(temp_storage.scan)
           .ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
 
       // Calculate gain
-      gpu_gpair parent_sum = d_nodes[node_idx].sum_gradients;
+      bst_gpair parent_sum = d_nodes[node_idx].sum_gradients;
       float parent_gain = d_nodes[node_idx].root_gain;
 
-      gpu_gpair missing = parent_sum - shared_sum;
+      bst_gpair missing = parent_sum - shared_sum;
 
       bool missing_left;
       float gain = thread_active
@@ -543,8 +543,8 @@ __global__ void find_split_kernel(
         fvalue = d_gidx_fvalue_map[gidx - 1];
       }
 
-      gpu_gpair left = missing_left ? bin + missing : bin;
-      gpu_gpair right = parent_sum - left;
+      bst_gpair left = missing_left ? bin + missing : bin;
+      bst_gpair right = parent_sum - left;
 
       split.Update(gain, missing_left, fvalue, fidx, left, right, gpu_param);
     }
@@ -581,16 +581,16 @@ __global__ void find_split_kernel(
 
     *Nodeleft = Node(
         split.left_sum,
-        CalcGain(gpu_param, split.left_sum.grad(), split.left_sum.hess()),
-        CalcWeight(gpu_param, split.left_sum.grad(), split.left_sum.hess()));
+        CalcGain(gpu_param, split.left_sum.grad, split.left_sum.hess),
+        CalcWeight(gpu_param, split.left_sum.grad, split.left_sum.hess));
 
     *Noderight = Node(
         split.right_sum,
-        CalcGain(gpu_param, split.right_sum.grad(), split.right_sum.hess()),
-        CalcWeight(gpu_param, split.right_sum.grad(), split.right_sum.hess()));
+        CalcGain(gpu_param, split.right_sum.grad, split.right_sum.hess),
+        CalcWeight(gpu_param, split.right_sum.grad, split.right_sum.hess));
 
     // Record smallest node
-    if (split.left_sum.hess() <= split.right_sum.hess()) {
+    if (split.left_sum.hess <= split.right_sum.hess) {
       *left_child_smallest = true;
     } else {
       *left_child_smallest = false;
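Taken together, these hunks show find_split_kernel's per-feature flow: block-reduce the feature's bins into shared_sum, exclusive-scan the bins so each thread holds the left-of-threshold total, derive the missing-value bucket, and score both default directions. Condensed from the lines above (same names, order of operations only; the full gain call is elided in this view):

    bst_gpair missing = parent_sum - shared_sum;  // gradients with no bin
    bool missing_left;
    float gain = loss_chg_missing(bin, missing, parent_sum, parent_gain,
                                  gpu_param, missing_left);
    bst_gpair left = missing_left ? bin + missing : bin;
    bst_gpair right = parent_sum - left;
    split.Update(gain, missing_left, fvalue, fidx, left, right, gpu_param);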
@@ -654,11 +654,11 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
 
       int nodes_offset_device = d_idx * num_nodes_device;
       find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-          (const gpu_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
+          (const bst_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
          feature_segments[d_idx].data(), depth, (info->num_col),
           (hmat_.row_ptr.back()), nodes[d_idx].data(), nodes_temp[d_idx].data(),
           nodes_child_temp[d_idx].data(), nodes_offset_device,
-          fidx_min_map[d_idx].data(), gidx_fvalue_map[d_idx].data(), gpu_param,
+          fidx_min_map[d_idx].data(), gidx_fvalue_map[d_idx].data(), GPUTrainingParam(param),
           left_child_smallest_temp[d_idx].data(), colsample,
           feature_flags[d_idx].data());
     }
@@ -751,11 +751,11 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
 
       int nodes_offset_device = d_idx * num_nodes_device;
       find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-          (const gpu_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
+          (const bst_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
           feature_segments[d_idx].data(), depth, (info->num_col),
           (hmat_.row_ptr.back()), nodes[d_idx].data(), NULL, NULL,
           nodes_offset_device, fidx_min_map[d_idx].data(),
-          gidx_fvalue_map[d_idx].data(), gpu_param,
+          gidx_fvalue_map[d_idx].data(), GPUTrainingParam(param),
          left_child_smallest[d_idx].data(), colsample,
           feature_flags[d_idx].data());
 
@@ -805,11 +805,11 @@ void GPUHistBuilder::LaunchFindSplit(int depth) {
 
     int nodes_offset_device = 0;
     find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-        (const gpu_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
+        (const bst_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
         feature_segments[d_idx].data(), depth, (info->num_col),
         (hmat_.row_ptr.back()), nodes[d_idx].data(), NULL, NULL,
         nodes_offset_device, fidx_min_map[d_idx].data(),
-        gidx_fvalue_map[d_idx].data(), gpu_param,
+        gidx_fvalue_map[d_idx].data(), GPUTrainingParam(param),
        left_child_smallest[d_idx].data(), colsample,
        feature_flags[d_idx].data());
   }
@@ -827,21 +827,21 @@ void GPUHistBuilder::InitFirstNode(const std::vector<bst_gpair>& gpair) {
   // and C:/Program Files (x86)/Microsoft Visual Studio
   // 14.0/VC/bin/../../VC/INCLUDE\future(1888): error : no instance of function
   // template "std::_Invoke_stored" matches the argument list
-  std::vector<gpu_gpair> future_results(n_devices);
+  std::vector<bst_gpair> future_results(n_devices);
   for (int d_idx = 0; d_idx < n_devices; d_idx++) {
     int device_idx = dList[d_idx];
 
     auto begin = device_gpair[d_idx].tbegin();
     auto end = device_gpair[d_idx].tend();
-    gpu_gpair init = gpu_gpair();
-    auto binary_op = thrust::plus<gpu_gpair>();
+    bst_gpair init = bst_gpair();
+    auto binary_op = thrust::plus<bst_gpair>();
 
     dh::safe_cuda(cudaSetDevice(device_idx));
     future_results[d_idx] = thrust::reduce(begin, end, init, binary_op);
   }
 
   // sum over devices on host (with blocking get())
-  gpu_gpair sum = gpu_gpair();
+  bst_gpair sum = bst_gpair();
   for (int d_idx = 0; d_idx < n_devices; d_idx++) {
     int device_idx = dList[d_idx];
     sum += future_results[d_idx];
@@ -849,7 +849,7 @@ void GPUHistBuilder::InitFirstNode(const std::vector<bst_gpair>& gpair) {
 #else
   // asynch reduce per device
 
-  std::vector<std::future<gpu_gpair>> future_results(n_devices);
+  std::vector<std::future<bst_gpair>> future_results(n_devices);
   for (int d_idx = 0; d_idx < n_devices; d_idx++) {
     // std::async captures the algorithm parameters by value
     // use std::launch::async to ensure the creation of a new thread
@@ -858,14 +858,14 @@
       dh::safe_cuda(cudaSetDevice(device_idx));
       auto begin = device_gpair[d_idx].tbegin();
       auto end = device_gpair[d_idx].tend();
-      gpu_gpair init = gpu_gpair();
-      auto binary_op = thrust::plus<gpu_gpair>();
+      bst_gpair init = bst_gpair();
+      auto binary_op = thrust::plus<bst_gpair>();
       return thrust::reduce(begin, end, init, binary_op);
     });
   }
 
   // sum over devices on host (with blocking get())
-  gpu_gpair sum = gpu_gpair();
+  bst_gpair sum = bst_gpair();
   for (int d_idx = 0; d_idx < n_devices; d_idx++) {
     int device_idx = dList[d_idx];
     sum += future_results[d_idx].get();
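Note: thrust::plus<bst_gpair> works here unmodified because the unified gradient-pair type now carries device-capable addition operators. A standalone sketch of the per-device async-reduce pattern this hunk uses, with a stand-in gpair type and plain CUDA calls replacing the plugin's dh:: helpers (names here are illustrative, not the plugin's API):

#include <future>
#include <vector>
#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/reduce.h>
#include <cuda_runtime.h>

struct gpair {  // stand-in for bst_gpair
  float grad, hess;
  __host__ __device__ gpair() : grad(0), hess(0) {}
  __host__ __device__ gpair(float g, float h) : grad(g), hess(h) {}
  __host__ __device__ gpair operator+(const gpair& rhs) const {
    return gpair(grad + rhs.grad, hess + rhs.hess);
  }
};

// Reduce one device_vector per GPU on its own host thread, then sum on the
// host. Assumes per_device[i] was allocated with device_ids[i] current.
gpair MultiGpuSum(std::vector<thrust::device_vector<gpair>>& per_device,
                  const std::vector<int>& device_ids) {
  std::vector<std::future<gpair>> futures;
  for (size_t i = 0; i < per_device.size(); ++i) {
    // std::launch::async forces a new thread; cudaSetDevice is per-thread,
    // so each reduction runs against its own GPU.
    futures.push_back(std::async(std::launch::async, [&, i] {
      cudaSetDevice(device_ids[i]);
      return thrust::reduce(per_device[i].begin(), per_device[i].end(),
                            gpair(), thrust::plus<gpair>());
    }));
  }
  gpair sum;
  for (auto& f : futures) sum = sum + f.get();  // blocking get()
  return sum;
}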
@@ -879,15 +879,15 @@ void GPUHistBuilder::InitFirstNode(const std::vector<bst_gpair>& gpair) {
     int device_idx = dList[d_idx];
 
     auto d_nodes = nodes[d_idx].data();
-    auto gpu_param_alias = gpu_param;
+    auto gpu_param = GPUTrainingParam(param);
 
     dh::launch_n(device_idx, 1, [=] __device__(int idx) {
-      gpu_gpair sum_gradients = sum;
+      bst_gpair sum_gradients = sum;
       d_nodes[idx] = Node(
           sum_gradients,
-          CalcGain(gpu_param_alias, sum_gradients.grad(), sum_gradients.hess()),
-          CalcWeight(gpu_param_alias, sum_gradients.grad(),
-                     sum_gradients.hess()));
+          CalcGain(gpu_param, sum_gradients.grad, sum_gradients.hess),
+          CalcWeight(gpu_param, sum_gradients.grad,
+                     sum_gradients.hess));
     });
   }
   // synch all devices to host before moving on (No, can avoid because BuildHist
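Note: dh::launch_n with n = 1 initializes the root node in place on each device, so a single device thread constructs a Node from the reduced gradient sum without any host-to-device struct copy. A sketch of what such a helper could look like (a hypothetical launcher, not the plugin's actual dh::launch_n definition); passing a __device__ lambda into a kernel like this requires nvcc's extended-lambda flag:

// Hypothetical generic launcher: run f(i) for every i in [0, n).
template <typename F>
__global__ void launch_n_kernel(int n, F f) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) f(i);
}

template <typename F>
void launch_n(int device_idx, int n, F f) {
  cudaSetDevice(device_idx);
  const int block = 256;
  launch_n_kernel<<<(n + block - 1) / block, block>>>(n, f);
}

// Usage mirroring the diff: one thread writes the root node.
//   launch_n(device_idx, 1, [=] __device__(int idx) { d_nodes[idx] = ...; });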
@@ -916,7 +916,7 @@ void GPUHistBuilder::UpdatePositionDense(int depth) {
     size_t end = device_row_segments[d_idx + 1];
 
     dh::launch_n(device_idx, end - begin, [=] __device__(int local_idx) {
-      NodeIdT pos = d_position[local_idx];
+      int pos = d_position[local_idx];
       if (!is_active(pos, depth)) {
         return;
       }
@@ -961,7 +961,7 @@ void GPUHistBuilder::UpdatePositionSparse(int depth) {
     // Update missing direction
     dh::launch_n(device_idx, row_end - row_begin,
                  [=] __device__(int local_idx) {
-                   NodeIdT pos = d_position[local_idx];
+                   int pos = d_position[local_idx];
                    if (!is_active(pos, depth)) {
                      d_position_tmp[local_idx] = pos;
                      return;
@@ -985,7 +985,7 @@ void GPUHistBuilder::UpdatePositionSparse(int depth) {
     dh::launch_n(
         device_idx, element_end - element_begin, [=] __device__(int local_idx) {
           int ridx = d_ridx[local_idx];
-          NodeIdT pos = d_position[ridx - row_begin];
+          int pos = d_position[ridx - row_begin];
          if (!is_active(pos, depth)) {
            return;
          }
@@ -31,16 +31,16 @@ struct DeviceGMat {
 };
 
 struct HistBuilder {
-  gpu_gpair *d_hist;
+  bst_gpair *d_hist;
   int n_bins;
-  __host__ __device__ HistBuilder(gpu_gpair *ptr, int n_bins);
-  __device__ void Add(gpu_gpair gpair, int gidx, int nidx) const;
-  __device__ gpu_gpair Get(int gidx, int nidx) const;
+  __host__ __device__ HistBuilder(bst_gpair *ptr, int n_bins);
+  __device__ void Add(bst_gpair gpair, int gidx, int nidx) const;
+  __device__ bst_gpair Get(int gidx, int nidx) const;
 };
 
 struct DeviceHist {
   int n_bins;
-  dh::dvec<gpu_gpair> data;
+  dh::dvec<bst_gpair> data;
 
   void Init(int max_depth);
 
@@ -48,7 +48,7 @@ struct DeviceHist {
 
   HistBuilder GetBuilder();
 
-  gpu_gpair *GetLevelPtr(int depth);
+  bst_gpair *GetLevelPtr(int depth);
 
   int LevelSize(int depth);
 };
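Note: HistBuilder::Add and Get are only declared in this header. As a hedged sketch of what device-side accumulation into such a histogram typically looks like (the node-major layout, free-function form, and names are assumptions, not the plugin's actual definition):

struct gpair_t { float grad, hess; };  // stand-in for bst_gpair

// Hypothetical sketch: fold one sample's gradient pair into the histogram
// bin (gidx) of its current node (nidx), assuming one pair per (node, bin)
// stored node-major. Component-wise atomicAdd is safe because the fields
// are plain floats and many threads may target the same bin.
__device__ inline void HistAdd(gpair_t* d_hist, int n_bins,
                               gpair_t g, int gidx, int nidx) {
  int idx = nidx * n_bins + gidx;
  atomicAdd(&d_hist[idx].grad, g.grad);
  atomicAdd(&d_hist[idx].hess, g.hess);
}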
@@ -61,8 +61,6 @@ class GPUHistBuilder {
 
   void UpdateParam(const TrainParam &param) {
     this->param = param;
-    this->gpu_param = GPUTrainingParam(param.min_child_weight, param.reg_lambda,
-                                       param.reg_alpha, param.max_delta_step);
   }
 
   void InitData(const std::vector<bst_gpair> &gpair, DMatrix &fmat,  // NOLINT
@@ -85,7 +83,6 @@ class GPUHistBuilder {
                std::vector<bst_float> *p_out_preds);
 
   TrainParam param;
-  GPUTrainingParam gpu_param;
   common::HistCutMatrix hmat_;
   common::GHistIndexMatrix gmat_;
   MetaInfo *info;
@@ -124,7 +121,7 @@ class GPUHistBuilder {
   std::vector<dh::dvec<int>> position;
   std::vector<dh::dvec<int>> position_tmp;
   std::vector<DeviceGMat> device_matrix;
-  std::vector<dh::dvec<gpu_gpair>> device_gpair;
+  std::vector<dh::dvec<bst_gpair>> device_gpair;
   std::vector<dh::dvec<int>> gidx_feature_map;
   std::vector<dh::dvec<float>> gidx_fvalue_map;
 
@@ -11,85 +11,6 @@
 namespace xgboost {
 namespace tree {
 
-typedef int16_t NodeIdT;
-
-// gpair type defined with device accessible functions
-struct gpu_gpair {
-  float _grad;
-  float _hess;
-
-  __host__ __device__ __forceinline__ float grad() const { return _grad; }
-
-  __host__ __device__ __forceinline__ float hess() const { return _hess; }
-
-  __host__ __device__ gpu_gpair() : _grad(0), _hess(0) {}
-
-  __host__ __device__ gpu_gpair(float g, float h) : _grad(g), _hess(h) {}
-
-  __host__ __device__ gpu_gpair(bst_gpair gpair)
-      : _grad(gpair.grad), _hess(gpair.hess) {}
-
-  __host__ __device__ bool operator==(const gpu_gpair &rhs) const {
-    return (_grad == rhs._grad) && (_hess == rhs._hess);
-  }
-
-  __host__ __device__ bool operator!=(const gpu_gpair &rhs) const {
-    return !(*this == rhs);
-  }
-
-  __host__ __device__ gpu_gpair &operator+=(const gpu_gpair &rhs) {
-    _grad += rhs._grad;
-    _hess += rhs._hess;
-    return *this;
-  }
-
-  __host__ __device__ gpu_gpair operator+(const gpu_gpair &rhs) const {
-    gpu_gpair g;
-    g._grad = _grad + rhs._grad;
-    g._hess = _hess + rhs._hess;
-    return g;
-  }
-
-  __host__ __device__ gpu_gpair &operator-=(const gpu_gpair &rhs) {
-    _grad -= rhs._grad;
-    _hess -= rhs._hess;
-    return *this;
-  }
-
-  __host__ __device__ gpu_gpair operator-(const gpu_gpair &rhs) const {
-    gpu_gpair g;
-    g._grad = _grad - rhs._grad;
-    g._hess = _hess - rhs._hess;
-    return g;
-  }
-
-  friend std::ostream &operator<<(std::ostream &os, const gpu_gpair &g) {
-    os << g.grad() << "/" << g.hess();
-    return os;
-  }
-
-  __host__ __device__ void print() const {
-    printf("%1.4f/%1.4f\n", grad(), hess());
-  }
-
-  __host__ __device__ bool approximate_compare(const gpu_gpair &b,
-                                               float g_eps = 0.1,
-                                               float h_eps = 0.1) const {
-    float gdiff = abs(this->grad() - b.grad());
-    float hdiff = abs(this->hess() - b.hess());
-
-    return (gdiff <= g_eps) && (hdiff <= h_eps);
-  }
-};
-
-typedef thrust::device_vector<bst_uint>::iterator uint_iter;
-typedef thrust::device_vector<gpu_gpair>::iterator gpair_iter;
-typedef thrust::device_vector<float>::iterator float_iter;
-typedef thrust::device_vector<NodeIdT>::iterator node_id_iter;
-typedef thrust::permutation_iterator<gpair_iter, uint_iter> gpair_perm_iter;
-typedef thrust::tuple<gpair_perm_iter, float_iter, node_id_iter> ItemTuple;
-typedef thrust::zip_iterator<ItemTuple> ItemIter;
-
 struct GPUTrainingParam {
   // minimum amount of hessian(weight) allowed in a child
   float min_child_weight;
@@ -104,6 +25,12 @@ struct GPUTrainingParam {
 
   __host__ __device__ GPUTrainingParam() {}
 
+  __host__ __device__ GPUTrainingParam(const TrainParam &param)
+      : min_child_weight(param.min_child_weight),
+        reg_lambda(param.reg_lambda),
+        reg_alpha(param.reg_alpha),
+        max_delta_step(param.max_delta_step) {}
+
   __host__ __device__ GPUTrainingParam(float min_child_weight_in,
                                        float reg_lambda_in, float reg_alpha_in,
                                        float max_delta_step_in)
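Note: this converting constructor is what lets the launch sites earlier in the diff shrink to GPUTrainingParam(param). The call-site effect, sketched with the TrainParam field names used in this diff:

// Before this commit: every call site spelled out the fields.
//   GPUTrainingParam gp(param.min_child_weight, param.reg_lambda,
//                       param.reg_alpha, param.max_delta_step);

// After: build the device-side parameter pack in one step.
GPUTrainingParam gp(param);  // or pass GPUTrainingParam(param) inline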
@@ -118,19 +45,19 @@ struct Split {
   bool missing_left;
   float fvalue;
   int findex;
-  gpu_gpair left_sum;
-  gpu_gpair right_sum;
+  bst_gpair left_sum;
+  bst_gpair right_sum;
 
   __host__ __device__ Split()
       : loss_chg(-FLT_MAX), missing_left(true), fvalue(0), findex(-1) {}
 
   __device__ void Update(float loss_chg_in, bool missing_left_in,
-                         float fvalue_in, int findex_in, gpu_gpair left_sum_in,
-                         gpu_gpair right_sum_in,
+                         float fvalue_in, int findex_in, bst_gpair left_sum_in,
+                         bst_gpair right_sum_in,
                          const GPUTrainingParam &param) {
     if (loss_chg_in > loss_chg &&
-        left_sum_in.hess() >= param.min_child_weight &&
-        right_sum_in.hess() >= param.min_child_weight) {
+        left_sum_in.hess >= param.min_child_weight &&
+        right_sum_in.hess >= param.min_child_weight) {
       loss_chg = loss_chg_in;
       missing_left = missing_left_in;
       fvalue = fvalue_in;
@@ -152,16 +79,16 @@ struct Split {
     }
   }
 
-  __host__ __device__ void Print() {
-    printf("Loss: %1.4f\n", loss_chg);
-    printf("Missing left: %d\n", missing_left);
-    printf("fvalue: %1.4f\n", fvalue);
-    printf("Left sum: ");
-    left_sum.print();
+  //__host__ __device__ void Print() {
+  // printf("Loss: %1.4f\n", loss_chg);
+  // printf("Missing left: %d\n", missing_left);
+  // printf("fvalue: %1.4f\n", fvalue);
+  // printf("Left sum: ");
+  // left_sum.print();
 
-    printf("Right sum: ");
-    right_sum.print();
-  }
+  // printf("Right sum: ");
+  // right_sum.print();
+  //}
 };
 
 struct split_reduce_op {
@@ -173,7 +100,7 @@ struct split_reduce_op {
 };
 
 struct Node {
-  gpu_gpair sum_gradients;
+  bst_gpair sum_gradients;
   float root_gain;
   float weight;
 
@@ -181,7 +108,7 @@ struct Node {
 
   __host__ __device__ Node() : weight(0), root_gain(0) {}
 
-  __host__ __device__ Node(gpu_gpair sum_gradients_in, float root_gain_in,
+  __host__ __device__ Node(bst_gpair sum_gradients_in, float root_gain_in,
                            float weight_in) {
     sum_gradients = sum_gradients_in;
     root_gain = root_gain_in;
@@ -14,11 +14,6 @@
 #include <limits>
 #include <vector>
 
-#ifdef __NVCC__
-#define XGB_DEVICE __host__ __device__
-#else
-#define XGB_DEVICE
-#endif
-
 namespace xgboost {
 namespace tree {
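Note: the removed XGB_DEVICE block above duplicated, under a second name, the macro this commit standardizes on, so param.h now tags its inline math helpers with the shared XGBOOST_DEVICE instead. A small usage sketch (the Clip helper is illustrative, not part of param.h):

// Illustrative only: a header-level helper callable from host and device.
// As the deleted block shows, the macro expands to __host__ __device__
// under nvcc and to nothing otherwise, so the same header compiles in a
// CPU-only build.
template <typename T>
XGBOOST_DEVICE inline T Clip(T v, T lo, T hi) {
  return v < lo ? lo : (v > hi ? hi : v);
}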
@@ -234,7 +229,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
 
 // functions for L1 cost
 template <typename T1, typename T2>
-XGB_DEVICE inline static T1 ThresholdL1(T1 w, T2 lambda) {
+XGBOOST_DEVICE inline static T1 ThresholdL1(T1 w, T2 lambda) {
   if (w > +lambda)
     return w - lambda;
   if (w < -lambda)
@@ -243,18 +238,18 @@ XGB_DEVICE inline static T1 ThresholdL1(T1 w, T2 lambda) {
 }
 
 template <typename T>
-XGB_DEVICE inline static T Sqr(T a) { return a * a; }
+XGBOOST_DEVICE inline static T Sqr(T a) { return a * a; }
 
 // calculate the cost of loss function
 template <typename TrainingParams, typename T>
-XGB_DEVICE inline T CalcGainGivenWeight(const TrainingParams &p, T sum_grad,
+XGBOOST_DEVICE inline T CalcGainGivenWeight(const TrainingParams &p, T sum_grad,
                                         T sum_hess, T w) {
   return -(2.0 * sum_grad * w + (sum_hess + p.reg_lambda) * Sqr(w));
 }
 
 // calculate the cost of loss function
 template <typename TrainingParams, typename T>
-XGB_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess) {
+XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess) {
   if (sum_hess < p.min_child_weight)
     return 0.0;
   if (p.max_delta_step == 0.0f) {
@@ -276,7 +271,7 @@ XGB_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess) {
 }
 // calculate cost of loss function with four statistics
 template <typename TrainingParams, typename T>
-XGB_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess,
+XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess,
                              T test_grad, T test_hess) {
   T w = CalcWeight(sum_grad, sum_hess);
   T ret = test_grad * w + 0.5 * (test_hess + p.reg_lambda) * Sqr(w);
@@ -288,7 +283,7 @@ XGB_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess,
 }
 // calculate weight given the statistics
 template <typename TrainingParams, typename T>
-XGB_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad,
+XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad,
                                T sum_hess) {
   if (sum_hess < p.min_child_weight)
     return 0.0;
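Note: a quick algebraic check of the helpers above, assuming reg_alpha = 0 and max_delta_step = 0 so the L1 threshold and clipping drop out. Writing G = sum_grad, H = sum_hess, and \lambda = p.reg_lambda, the optimal leaf weight and the gain CalcGainGivenWeight reports for it are

w^* = -\frac{G}{H + \lambda},
\qquad
-\bigl(2\,G\,w^* + (H + \lambda)\,(w^*)^2\bigr)
  = \frac{2G^2}{H + \lambda} - \frac{G^2}{H + \lambda}
  = \frac{G^2}{H + \lambda},

so CalcGain and CalcGainGivenWeight agree at the optimum, which is pure algebra on the expressions shown and adds no new claims about the code.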