Integer gradient summation for GPU histogram algorithm. (#2681)
parent 15267eedf2
commit e6a9063344
@@ -87,65 +87,116 @@ typedef uint64_t bst_ulong; // NOLINT(*)
 typedef float bst_float;
 
-/*! \brief Implementation of gradient statistics pair */
+namespace detail {
+
+/*! \brief Implementation of gradient statistics pair. Template specialisation
+ * may be used to overload different gradient types e.g. low precision, high
+ * precision, integer, floating point. */
 template <typename T>
-struct bst_gpair_internal {
+class bst_gpair_internal {
   /*! \brief gradient statistics */
-  T grad;
+  T grad_;
   /*! \brief second order gradient statistics */
-  T hess;
+  T hess_;
 
-  XGBOOST_DEVICE bst_gpair_internal() : grad(0), hess(0) {}
+  XGBOOST_DEVICE void SetGrad(float g) { grad_ = g; }
+  XGBOOST_DEVICE void SetHess(float h) { hess_ = h; }
 
-  XGBOOST_DEVICE bst_gpair_internal(T grad, T hess)
-      : grad(grad), hess(hess) {}
+ public:
+  typedef T value_t;
+
+  XGBOOST_DEVICE bst_gpair_internal() : grad_(0), hess_(0) {}
+
+  XGBOOST_DEVICE bst_gpair_internal(float grad, float hess) {
+    SetGrad(grad);
+    SetHess(hess);
+  }
+
+  // Copy constructor if of same value type
+  XGBOOST_DEVICE bst_gpair_internal(const bst_gpair_internal<T> &g)
+      : grad_(g.grad_), hess_(g.hess_) {}
+
+  // Copy constructor if different value type - use getters and setters to
+  // perform conversion
   template <typename T2>
-  XGBOOST_DEVICE bst_gpair_internal(bst_gpair_internal<T2> &g)
-      : grad(g.grad), hess(g.hess) {}
+  XGBOOST_DEVICE bst_gpair_internal(const bst_gpair_internal<T2> &g) {
+    SetGrad(g.GetGrad());
+    SetHess(g.GetHess());
+  }
+
+  XGBOOST_DEVICE float GetGrad() const { return grad_; }
+  XGBOOST_DEVICE float GetHess() const { return hess_; }
 
-  XGBOOST_DEVICE bst_gpair_internal<T> &operator+=(const bst_gpair_internal<T> &rhs) {
-    grad += rhs.grad;
-    hess += rhs.hess;
+  XGBOOST_DEVICE bst_gpair_internal<T> &operator+=(
+      const bst_gpair_internal<T> &rhs) {
+    grad_ += rhs.grad_;
+    hess_ += rhs.hess_;
     return *this;
   }
 
-  XGBOOST_DEVICE bst_gpair_internal<T> operator+(const bst_gpair_internal<T> &rhs) const {
+  XGBOOST_DEVICE bst_gpair_internal<T> operator+(
+      const bst_gpair_internal<T> &rhs) const {
     bst_gpair_internal<T> g;
-    g.grad = grad + rhs.grad;
-    g.hess = hess + rhs.hess;
+    g.grad_ = grad_ + rhs.grad_;
+    g.hess_ = hess_ + rhs.hess_;
     return g;
   }
 
-  XGBOOST_DEVICE bst_gpair_internal<T> &operator-=(const bst_gpair_internal<T> &rhs) {
-    grad -= rhs.grad;
-    hess -= rhs.hess;
+  XGBOOST_DEVICE bst_gpair_internal<T> &operator-=(
+      const bst_gpair_internal<T> &rhs) {
+    grad_ -= rhs.grad_;
+    hess_ -= rhs.hess_;
     return *this;
   }
 
-  XGBOOST_DEVICE bst_gpair_internal<T> operator-(const bst_gpair_internal<T> &rhs) const {
+  XGBOOST_DEVICE bst_gpair_internal<T> operator-(
+      const bst_gpair_internal<T> &rhs) const {
     bst_gpair_internal<T> g;
-    g.grad = grad - rhs.grad;
-    g.hess = hess - rhs.hess;
+    g.grad_ = grad_ - rhs.grad_;
+    g.hess_ = hess_ - rhs.hess_;
     return g;
   }
 
   XGBOOST_DEVICE bst_gpair_internal(int value) {
-    *this = bst_gpair_internal<T>(static_cast<float>(value), static_cast<float>(value));
+    *this = bst_gpair_internal<T>(static_cast<float>(value),
+                                  static_cast<float>(value));
   }
 
   friend std::ostream &operator<<(std::ostream &os,
                                   const bst_gpair_internal<T> &g) {
-    os << g.grad << "/" << g.hess;
+    os << g.grad_ << "/" << g.hess_;
     return os;
   }
 };
 
+template<>
+inline XGBOOST_DEVICE float bst_gpair_internal<int64_t>::GetGrad() const {
+  return grad_ * 1e-5;
+}
+template<>
+inline XGBOOST_DEVICE float bst_gpair_internal<int64_t>::GetHess() const {
+  return hess_ * 1e-5;
+}
+template<>
+inline XGBOOST_DEVICE void bst_gpair_internal<int64_t>::SetGrad(float g) {
+  grad_ = g * 1e5;
+}
+template<>
+inline XGBOOST_DEVICE void bst_gpair_internal<int64_t>::SetHess(float h) {
+  hess_ = h * 1e5;
+}
+
+}  // namespace detail
+
 /*! \brief gradient statistics pair usually needed in gradient boosting */
-typedef bst_gpair_internal<float> bst_gpair;
+typedef detail::bst_gpair_internal<float> bst_gpair;
 
 /*! \brief High precision gradient statistics pair */
-typedef bst_gpair_internal<double> bst_gpair_precise;
+typedef detail::bst_gpair_internal<double> bst_gpair_precise;
+
+/*! \brief High precision gradient statistics pair with integer backed
+ * storage. Operators are associative where floating point versions are not
+ * associative. */
+typedef detail::bst_gpair_internal<int64_t> bst_gpair_integer;
 
 /*! \brief small eps gap for minimum split decision. */
 const bst_float rt_eps = 1e-6f;
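The int64_t specialisations above implement a fixed-point scheme: values are scaled by 1e5 and truncated on write (SetGrad/SetHess) and scaled back by 1e-5 on read (GetGrad/GetHess), giving roughly five decimal digits of fractional precision. Because the stored words are plain integers, sums over them are exact and independent of evaluation order. A minimal standalone sketch of the round trip, assuming the same 1e5 scale (names are local to this example, not part of the patch):

    #include <cstdint>
    #include <iostream>

    int main() {
      const float scale = 1e5f;                         // the SetGrad/SetHess factor
      float g = 0.123456789f;                           // incoming gradient
      int64_t fixed = static_cast<int64_t>(g * scale);  // encode, as in SetGrad
      float back = fixed * 1e-5f;                       // decode, as in GetGrad
      std::cout << fixed << " " << back << "\n";        // prints: 12345 0.12345
      return 0;
    }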
@@ -33,8 +33,8 @@ struct GHistEntry {
 
   /*! \brief add a bst_gpair to the sum */
   inline void Add(const bst_gpair& e) {
-    sum_grad += e.grad;
-    sum_hess += e.hess;
+    sum_grad += e.GetGrad();
+    sum_hess += e.GetHess();
   }
 
   /*! \brief add a GHistEntry to the sum */
@@ -120,8 +120,9 @@ class GBLinear : public GradientBooster {
 #pragma omp parallel for schedule(static) reduction(+: sum_grad, sum_hess)
       for (bst_omp_uint i = 0; i < ndata; ++i) {
         bst_gpair &p = gpair[rowset[i] * ngroup + gid];
-        if (p.hess >= 0.0f) {
-          sum_grad += p.grad; sum_hess += p.hess;
+        if (p.GetHess() >= 0.0f) {
+          sum_grad += p.GetGrad();
+          sum_hess += p.GetHess();
         }
       }
       // remove bias effect
@@ -132,8 +133,8 @@ class GBLinear : public GradientBooster {
 #pragma omp parallel for schedule(static)
       for (bst_omp_uint i = 0; i < ndata; ++i) {
         bst_gpair &p = gpair[rowset[i] * ngroup + gid];
-        if (p.hess >= 0.0f) {
-          p.grad += p.hess * dw;
+        if (p.GetHess() >= 0.0f) {
+          p += bst_gpair(p.GetHess() * dw, 0);
         }
       }
     }
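Since grad and hess are now private, call sites such as the one above can no longer write p.grad directly; the update is expressed as operator+= with a delta pair whose hessian component is zero. A self-contained sketch of the pattern (pair_like is a stand-in for bst_gpair's public surface, not from the patch):

    #include <iostream>

    struct pair_like {
      pair_like(float g, float h) : grad_(g), hess_(h) {}
      float GetGrad() const { return grad_; }
      float GetHess() const { return hess_; }
      pair_like &operator+=(const pair_like &rhs) {
        grad_ += rhs.grad_;
        hess_ += rhs.hess_;
        return *this;
      }
     private:
      float grad_, hess_;
    };

    int main() {
      pair_like p(0.5f, 2.0f);
      float dw = 0.1f;
      p += pair_like(p.GetHess() * dw, 0);  // old form: p.grad += p.hess * dw
      std::cout << p.GetGrad() << " " << p.GetHess() << "\n";  // 0.7 2
      return 0;
    }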
@@ -151,9 +152,9 @@ class GBLinear : public GradientBooster {
       for (bst_uint j = 0; j < col.length; ++j) {
         const bst_float v = col[j].fvalue;
         bst_gpair &p = gpair[col[j].index * ngroup + gid];
-        if (p.hess < 0.0f) continue;
-        sum_grad += p.grad * v;
-        sum_hess += p.hess * v * v;
+        if (p.GetHess() < 0.0f) continue;
+        sum_grad += p.GetGrad() * v;
+        sum_hess += p.GetHess() * v * v;
       }
       bst_float &w = model[fid][gid];
       bst_float dw = static_cast<bst_float>(param.learning_rate *
@@ -162,8 +163,8 @@ class GBLinear : public GradientBooster {
       // update grad value
       for (bst_uint j = 0; j < col.length; ++j) {
         bst_gpair &p = gpair[col[j].index * ngroup + gid];
-        if (p.hess < 0.0f) continue;
-        p.grad += p.hess * col[j].fvalue * dw;
+        if (p.GetHess() < 0.0f) continue;
+        p += bst_gpair(p.GetHess() * col[j].fvalue * dw, 0);
       }
     }
   }
@@ -109,10 +109,8 @@ class LambdaRankObj : public ObjFunction {
         bst_float g = p - 1.0f;
         bst_float h = std::max(p * (1.0f - p), eps);
         // accumulate gradient and hessian in both pid, and nid
-        gpair[pos.rindex].grad += g * w;
-        gpair[pos.rindex].hess += 2.0f * w * h;
-        gpair[neg.rindex].grad -= g * w;
-        gpair[neg.rindex].hess += 2.0f * w * h;
+        gpair[pos.rindex] += bst_gpair(g * w, 2.0f*w*h);
+        gpair[neg.rindex] += bst_gpair(-g * w, 2.0f*w*h);
       }
     }
   }
@@ -313,7 +313,7 @@ struct XGBOOST_ALIGNAS(16) GradStats {
    * \brief accumulate statistics
    * \param p the gradient pair
    */
-  inline void Add(bst_gpair p) { this->Add(p.grad, p.hess); }
+  inline void Add(bst_gpair p) { this->Add(p.GetGrad(), p.GetHess()); }
   /*!
    * \brief accumulate statistics, more complicated version
    * \param gpair the vector storing the gradient statistics
@@ -323,7 +323,7 @@ struct XGBOOST_ALIGNAS(16) GradStats {
   inline void Add(const std::vector<bst_gpair>& gpair, const MetaInfo& info,
                   bst_uint ridx) {
     const bst_gpair& b = gpair[ridx];
-    this->Add(b.grad, b.hess);
+    this->Add(b.GetGrad(), b.GetHess());
   }
   /*! \brief calculate leaf weight */
   inline double CalcWeight(const TrainParam& param) const {
@@ -140,14 +140,14 @@ class BaseMaker: public TreeUpdater {
     }
     // mark delete for the deleted datas
     for (size_t i = 0; i < position.size(); ++i) {
-      if (gpair[i].hess < 0.0f) position[i] = ~position[i];
+      if (gpair[i].GetHess() < 0.0f) position[i] = ~position[i];
     }
     // mark subsample
     if (param.subsample < 1.0f) {
       std::bernoulli_distribution coin_flip(param.subsample);
       auto& rnd = common::GlobalRandom();
       for (size_t i = 0; i < position.size(); ++i) {
-        if (gpair[i].hess < 0.0f) continue;
+        if (gpair[i].GetHess() < 0.0f) continue;
         if (!coin_flip(rnd)) position[i] = ~position[i];
       }
     }
@@ -136,7 +136,7 @@ class ColMaker: public TreeUpdater {
     // mark delete for the deleted datas
     for (size_t i = 0; i < rowset.size(); ++i) {
       const bst_uint ridx = rowset[i];
-      if (gpair[ridx].hess < 0.0f) position[ridx] = ~position[ridx];
+      if (gpair[ridx].GetHess() < 0.0f) position[ridx] = ~position[ridx];
     }
     // mark subsample
     if (param.subsample < 1.0f) {
@@ -144,7 +144,7 @@ class ColMaker: public TreeUpdater {
       auto& rnd = common::GlobalRandom();
       for (size_t i = 0; i < rowset.size(); ++i) {
         const bst_uint ridx = rowset[i];
-        if (gpair[ridx].hess < 0.0f) continue;
+        if (gpair[ridx].GetHess() < 0.0f) continue;
         if (!coin_flip(rnd)) position[ridx] = ~position[ridx];
       }
     }
@@ -372,13 +372,13 @@ class FastHistMaker: public TreeUpdater {
       std::bernoulli_distribution coin_flip(param.subsample);
       auto& rnd = common::GlobalRandom();
       for (size_t i = 0; i < info.num_row; ++i) {
-        if (gpair[i].hess >= 0.0f && coin_flip(rnd)) {
+        if (gpair[i].GetHess() >= 0.0f && coin_flip(rnd)) {
           row_indices.push_back(i);
         }
       }
     } else {
       for (size_t i = 0; i < info.num_row; ++i) {
-        if (gpair[i].hess >= 0.0f) {
+        if (gpair[i].GetHess() >= 0.0f) {
           row_indices.push_back(i);
         }
       }
@@ -82,8 +82,8 @@ struct DeviceDenseNode {
         fvalue(0.f),
         fidx(UNUSED_NODE),
         idx(nidx) {
-    this->root_gain = CalcGain(param, sum_gradients.grad, sum_gradients.hess);
-    this->weight = CalcWeight(param, sum_gradients.grad, sum_gradients.hess);
+    this->root_gain = CalcGain(param, sum_gradients.GetGrad(), sum_gradients.GetHess());
+    this->weight = CalcWeight(param, sum_gradients.GetGrad(), sum_gradients.GetHess());
   }
 
   HOST_DEV_INLINE void SetSplit(float fvalue, int fidx, DefaultDirection dir) {
@@ -113,8 +113,8 @@ __device__ inline float device_calc_loss_chg(
 
   gpair_t right = parent_sum - left;
 
-  float left_gain = CalcGain(param, left.grad, left.hess);
-  float right_gain = CalcGain(param, right.grad, right.hess);
+  float left_gain = CalcGain(param, left.GetGrad(), left.GetHess());
+  float right_gain = CalcGain(param, right.GetGrad(), right.GetHess());
   return left_gain + right_gain - parent_gain;
 }
 
@@ -181,13 +181,13 @@ inline void dense2sparse_tree(RegTree* p_tree,
       tree[nid].set_split(n.fidx, n.fvalue, n.dir == LeftDir);
       tree.stat(nid).loss_chg = n.root_gain;
       tree.stat(nid).base_weight = n.weight;
-      tree.stat(nid).sum_hess = n.sum_gradients.hess;
+      tree.stat(nid).sum_hess = n.sum_gradients.GetHess();
       tree[tree[nid].cleft()].set_leaf(0);
       tree[tree[nid].cright()].set_leaf(0);
       nid++;
     } else if (n.IsLeaf()) {
       tree[nid].set_leaf(n.weight * param.learning_rate);
-      tree.stat(nid).sum_hess = n.sum_gradients.hess;
+      tree.stat(nid).sum_hess = n.sum_gradients.GetHess();
       nid++;
     }
   }
@@ -5,17 +5,20 @@
 #include <memory>
 #include <utility>
 #include <vector>
-#include "param.h"
 #include "../common/compressed_iterator.h"
-#include "../common/hist_util.h"
-#include "updater_gpu_common.cuh"
 #include "../common/device_helpers.cuh"
+#include "../common/hist_util.h"
+#include "param.h"
+#include "updater_gpu_common.cuh"
 
 namespace xgboost {
 namespace tree {
 
 DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
 
+typedef bst_gpair_integer gpair_sum_t;
+static const ncclDataType_t nccl_sum_t = ncclInt64;
+
 // Helper for explicit template specialisation
 template <int N>
 struct Int {};
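gpair_sum_t aliases the integer-backed pair and nccl_sum_t tells NCCL to reduce int64 words. The motivation is determinism: floating-point addition is not associative, so reductions that group terms differently can disagree, while integer addition always agrees. A toy standalone illustration (not from the patch):

    #include <cstdint>
    #include <iostream>

    int main() {
      float a = 1e8f, b = -1e8f, c = 1.0f;
      std::cout << (a + b) + c << " vs " << a + (b + c) << "\n";  // 1 vs 0
      int64_t x = 10000000000000LL, y = -10000000000000LL, z = 100000;
      std::cout << (x + y) + z << " vs " << x + (y + z) << "\n";  // equal
      return 0;
    }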
@@ -50,27 +53,29 @@ struct DeviceGMat {
 };
 
 struct HistHelper {
-  bst_gpair* d_hist;
+  gpair_sum_t* d_hist;
   int n_bins;
-  __host__ __device__ HistHelper(bst_gpair* ptr, int n_bins)
+  __host__ __device__ HistHelper(gpair_sum_t* ptr, int n_bins)
       : d_hist(ptr), n_bins(n_bins) {}
 
   __device__ void Add(bst_gpair gpair, int gidx, int nidx) const {
     int hist_idx = nidx * n_bins + gidx;
-    atomicAdd(&(d_hist[hist_idx].grad), gpair.grad);  // OPTMARK: This and below
-                                                      // line lead to about 3X
-                                                      // slowdown due to memory
-                                                      // dependency and access
-                                                      // pattern issues.
-    atomicAdd(&(d_hist[hist_idx].hess), gpair.hess);
+
+    auto dst_ptr = reinterpret_cast<unsigned long long int*>(&d_hist[hist_idx]);  // NOLINT
+    gpair_sum_t tmp(gpair.GetGrad(), gpair.GetHess());
+    auto src_ptr = reinterpret_cast<gpair_sum_t::value_t*>(&tmp);
+
+    atomicAdd(dst_ptr, static_cast<unsigned long long int>(*src_ptr));  // NOLINT
+    atomicAdd(dst_ptr + 1, static_cast<unsigned long long int>(*(src_ptr + 1)));  // NOLINT
   }
-  __device__ bst_gpair Get(int gidx, int nidx) const {
+  __device__ gpair_sum_t Get(int gidx, int nidx) const {
     return d_hist[nidx * n_bins + gidx];
   }
 };
 
 struct DeviceHist {
   int n_bins;
-  dh::dvec<bst_gpair> data;
+  dh::dvec<gpair_sum_t> data;
 
   void Init(int n_bins_in) {
     this->n_bins = n_bins_in;
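The reinterpret_cast dance in Add() exists because CUDA's atomicAdd has no signed 64-bit overload; the unsigned long long overload is applied to each word of the pair instead. That is safe because two's-complement addition on the raw bits is bit-identical to signed addition. A host-side sketch of the equivalence (standalone example):

    #include <cstdint>
    #include <iostream>

    int main() {
      int64_t a = -123456789, b = 987654321;
      // What the unsigned atomicAdd does, bit for bit:
      uint64_t raw = static_cast<uint64_t>(a) + static_cast<uint64_t>(b);
      std::cout << a + b << " == " << static_cast<int64_t>(raw) << "\n";
      return 0;
    }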
@@ -79,12 +84,12 @@ struct DeviceHist {
 
   void Reset(int device_idx) {
     cudaSetDevice(device_idx);
-    data.fill(bst_gpair());
+    data.fill(gpair_sum_t());
   }
 
   HistHelper GetBuilder() { return HistHelper(data.data(), n_bins); }
 
-  bst_gpair* GetLevelPtr(int depth) {
+  gpair_sum_t* GetLevelPtr(int depth) {
     return data.data() + n_nodes(depth - 1) * n_bins;
   }
 
@@ -96,18 +101,19 @@ struct SplitCandidate {
   bool missing_left;
   float fvalue;
   int findex;
-  bst_gpair left_sum;
-  bst_gpair right_sum;
+  gpair_sum_t left_sum;
+  gpair_sum_t right_sum;
 
   __host__ __device__ SplitCandidate()
       : loss_chg(-FLT_MAX), missing_left(true), fvalue(0), findex(-1) {}
 
   __device__ void Update(float loss_chg_in, bool missing_left_in,
-                         float fvalue_in, int findex_in, bst_gpair left_sum_in,
-                         bst_gpair right_sum_in,
+                         float fvalue_in, int findex_in,
+                         gpair_sum_t left_sum_in, gpair_sum_t right_sum_in,
                          const GPUTrainingParam& param) {
-    if (loss_chg_in > loss_chg && left_sum_in.hess >= param.min_child_weight &&
-        right_sum_in.hess >= param.min_child_weight) {
+    if (loss_chg_in > loss_chg &&
+        left_sum_in.GetHess() >= param.min_child_weight &&
+        right_sum_in.GetHess() >= param.min_child_weight) {
       loss_chg = loss_chg_in;
       missing_left = missing_left_in;
       fvalue = fvalue_in;
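min_child_weight remains a float hyperparameter, so the guard above compares the decoded hessian from GetHess() rather than the raw fixed-point word. A one-line sketch of the decode-then-compare, assuming the 1e5 scale:

    #include <cstdint>
    #include <iostream>

    int main() {
      int64_t hess_fixed = 150000;  // SetHess(1.5f) under the 1e5 scale
      float min_child_weight = 1.0f;
      std::cout << (hess_fixed * 1e-5f >= min_child_weight) << "\n";  // 1
      return 0;
    }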
@@ -121,11 +127,11 @@ struct SplitCandidate {
 
 struct GpairCallbackOp {
   // Running prefix
-  bst_gpair running_total;
+  gpair_sum_t running_total;
   // Constructor
-  __device__ GpairCallbackOp() : running_total(bst_gpair()) {}
+  __device__ GpairCallbackOp() : running_total(gpair_sum_t()) {}
   __device__ bst_gpair operator()(bst_gpair block_aggregate) {
-    bst_gpair old_prefix = running_total;
+    gpair_sum_t old_prefix = running_total;
     running_total += block_aggregate;
     return old_prefix;
   }
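GpairCallbackOp is the running-prefix functor that cub::BlockScan invokes once per BLOCK_THREADS-sized tile: it returns the sum of everything scanned so far and folds in the current tile's aggregate. Over the whole bin range this yields an exclusive prefix sum, i.e. each bin ends up holding the gradient sum of all bins to its left, which is exactly the left-child sum a split candidate needs. A CPU sketch of the same tiled scan (assumed semantics, standalone):

    #include <algorithm>
    #include <iostream>
    #include <vector>

    int main() {
      std::vector<long long> bins = {3, 1, 4, 1, 5, 9, 2, 6};
      const size_t tile = 4;        // stand-in for BLOCK_THREADS
      long long running_total = 0;  // GpairCallbackOp::running_total
      for (size_t begin = 0; begin < bins.size(); begin += tile) {
        long long tile_sum = 0;
        size_t end = std::min(begin + tile, bins.size());
        for (size_t i = begin; i < end; ++i) {
          long long v = bins[i];
          bins[i] = running_total + tile_sum;  // exclusive prefix
          tile_sum += v;
        }
        running_total += tile_sum;  // what the callback accumulates
      }
      for (long long v : bins) std::cout << v << " ";  // 0 3 4 8 9 14 23 25
      std::cout << "\n";
      return 0;
    }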
@@ -133,17 +139,16 @@ struct GpairCallbackOp {
 
 template <int BLOCK_THREADS>
 __global__ void find_split_kernel(
-    const bst_gpair* d_level_hist, int* d_feature_segments, int depth,
+    const gpair_sum_t* d_level_hist, int* d_feature_segments, int depth,
     int n_features, int n_bins, DeviceDenseNode* d_nodes,
     int nodes_offset_device, float* d_fidx_min_map, float* d_gidx_fvalue_map,
     GPUTrainingParam gpu_param, bool* d_left_child_smallest_temp,
     bool colsample, int* d_feature_flags) {
   typedef cub::KeyValuePair<int, float> ArgMaxT;
-  typedef cub::BlockScan<bst_gpair, BLOCK_THREADS,
-                         cub::BLOCK_SCAN_WARP_SCANS>
+  typedef cub::BlockScan<gpair_sum_t, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>
       BlockScanT;
   typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT;
-  typedef cub::BlockReduce<bst_gpair, BLOCK_THREADS> SumReduceT;
+  typedef cub::BlockReduce<gpair_sum_t, BLOCK_THREADS> SumReduceT;
 
   union TempStorage {
     typename BlockScanT::TempStorage scan;
@@ -153,8 +158,8 @@ __global__ void find_split_kernel(
 
   __shared__ cub::Uninitialized<SplitCandidate> uninitialized_split;
   SplitCandidate& split = uninitialized_split.Alias();
-  __shared__ cub::Uninitialized<bst_gpair> uninitialized_sum;
-  bst_gpair& shared_sum = uninitialized_sum.Alias();
+  __shared__ cub::Uninitialized<gpair_sum_t> uninitialized_sum;
+  gpair_sum_t& shared_sum = uninitialized_sum.Alias();
   __shared__ ArgMaxT block_max;
   __shared__ TempStorage temp_storage;
 
@@ -175,14 +180,13 @@ __global__ void find_split_kernel(
     int begin = d_feature_segments[level_node_idx * n_features + fidx];
     int end = d_feature_segments[level_node_idx * n_features + fidx + 1];
 
-    bst_gpair feature_sum = bst_gpair();
+    gpair_sum_t feature_sum = gpair_sum_t();
     for (int reduce_begin = begin; reduce_begin < end;
          reduce_begin += BLOCK_THREADS) {
       bool thread_active = reduce_begin + threadIdx.x < end;
       // Scan histogram
-      bst_gpair bin = thread_active
-                          ? d_level_hist[reduce_begin + threadIdx.x]
-                          : bst_gpair();
+      gpair_sum_t bin = thread_active ? d_level_hist[reduce_begin + threadIdx.x]
+                                      : gpair_sum_t();
 
       feature_sum +=
           SumReduceT(temp_storage.sum_reduce).Reduce(bin, cub::Sum());
@@ -197,18 +201,17 @@ __global__ void find_split_kernel(
     for (int scan_begin = begin; scan_begin < end;
          scan_begin += BLOCK_THREADS) {
       bool thread_active = scan_begin + threadIdx.x < end;
-      bst_gpair bin = thread_active
-                          ? d_level_hist[scan_begin + threadIdx.x]
-                          : bst_gpair();
+      gpair_sum_t bin = thread_active ? d_level_hist[scan_begin + threadIdx.x]
+                                      : gpair_sum_t();
 
       BlockScanT(temp_storage.scan)
           .ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
 
       // Calculate gain
-      bst_gpair parent_sum = d_nodes[node_idx].sum_gradients;
+      gpair_sum_t parent_sum = gpair_sum_t(d_nodes[node_idx].sum_gradients);
       float parent_gain = d_nodes[node_idx].root_gain;
 
-      bst_gpair missing = parent_sum - shared_sum;
+      gpair_sum_t missing = parent_sum - shared_sum;
 
       bool missing_left;
       float gain = thread_active
@@ -239,8 +242,8 @@ __global__ void find_split_kernel(
           fvalue = d_gidx_fvalue_map[gidx - 1];
         }
 
-        bst_gpair left = missing_left ? bin + missing : bin;
-        bst_gpair right = parent_sum - left;
+        gpair_sum_t left = missing_left ? bin + missing : bin;
+        gpair_sum_t right = parent_sum - left;
 
         split.Update(gain, missing_left, fvalue, fidx, left, right, gpu_param);
       }
@@ -263,7 +266,7 @@ __global__ void find_split_kernel(
         DeviceDenseNode(split.right_sum, right_child_nidx(node_idx), gpu_param);
 
     // Record smallest node
-    if (split.left_sum.hess <= split.right_sum.hess) {
+    if (split.left_sum.GetHess() <= split.right_sum.GetHess()) {
       left_child_smallest = true;
     } else {
       left_child_smallest = false;
@@ -595,6 +598,7 @@ class GPUHistMaker : public TreeUpdater {
 
     initialised = true;
   }
+
   void BuildHist(int depth) {
     for (int d_idx = 0; d_idx < n_devices; d_idx++) {
       int device_idx = dList[d_idx];
@@ -650,9 +654,9 @@ class GPUHistMaker : public TreeUpdater {
       dh::safe_nccl(ncclAllReduce(
           reinterpret_cast<const void*>(hist_vec[d_idx].GetLevelPtr(depth)),
           reinterpret_cast<void*>(hist_vec[d_idx].GetLevelPtr(depth)),
-          hist_vec[d_idx].LevelSize(depth) * sizeof(bst_gpair) /
-              sizeof(float),
-          ncclFloat, ncclSum, comms[d_idx], *(streams[d_idx])));
+          hist_vec[d_idx].LevelSize(depth) * sizeof(gpair_sum_t) /
+              sizeof(gpair_sum_t::value_t),
+          nccl_sum_t, ncclSum, comms[d_idx], *(streams[d_idx])));
     }
 
     for (int d_idx = 0; d_idx < n_devices; d_idx++) {
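The reduction count above is the number of int64 words, not pairs: LevelSize(depth) entries, each sizeof(gpair_sum_t) bytes, divided by sizeof(gpair_sum_t::value_t). Since integer addition is associative, ncclSum over int64 gives the same answer no matter how NCCL orders the partial sums across devices. A sketch of the count arithmetic (pair64 is a stand-in, not from the patch):

    #include <cstdint>
    #include <iostream>

    struct pair64 { int64_t grad_, hess_; };  // stand-in for gpair_sum_t

    int main() {
      size_t level_size = 256;  // histogram entries at some tree level (example)
      size_t count = level_size * sizeof(pair64) / sizeof(int64_t);
      std::cout << count << " int64 words to allreduce\n";  // 512
      return 0;
    }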
@@ -683,11 +687,12 @@ class GPUHistMaker : public TreeUpdater {
         }
 
         int gidx = idx % hist_builder.n_bins;
-        bst_gpair parent = hist_builder.Get(gidx, parent_nidx(nidx));
+        gpair_sum_t parent = hist_builder.Get(gidx, parent_nidx(nidx));
         int other_nidx = left_smallest ? nidx - 1 : nidx + 1;
-        bst_gpair other = hist_builder.Get(gidx, other_nidx);
+        gpair_sum_t other = hist_builder.Get(gidx, other_nidx);
+        gpair_sum_t sub = parent - other;
         hist_builder.Add(
-            parent - other, gidx,
+            bst_gpair(sub.GetGrad(), sub.GetHess()), gidx,
             nidx);  // OPTMARK: This is slow, could use shared
                     // memory or cache results instead of writing to
                     // global memory every time in atomic way.
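This is the classic histogram subtraction trick: the larger child's histogram is derived as parent minus sibling instead of being rebuilt from the rows. It is only guaranteed to match a rebuilt histogram when the arithmetic is exact, which the integer representation provides; with floats, parent minus sibling can pick up rounding error. A toy standalone illustration:

    #include <cstdint>
    #include <iostream>

    int main() {
      int64_t left = 12345, right = 67890;  // fixed-point bin sums
      int64_t parent = left + right;
      std::cout << (parent - right == left) << "\n";  // always 1 for integers
      return 0;
    }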
@@ -737,11 +742,11 @@ class GPUHistMaker : public TreeUpdater {
 
       int nodes_offset_device = 0;
       find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-          (const bst_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
-          feature_segments[d_idx].data(), depth, (info->num_col),
-          (hmat_.row_ptr.back()), nodes[d_idx].data(), nodes_offset_device,
-          fidx_min_map[d_idx].data(), gidx_fvalue_map[d_idx].data(),
-          GPUTrainingParam(param), left_child_smallest[d_idx].data(), colsample,
+          hist_vec[d_idx].GetLevelPtr(depth), feature_segments[d_idx].data(),
+          depth, (info->num_col), (hmat_.row_ptr.back()), nodes[d_idx].data(),
+          nodes_offset_device, fidx_min_map[d_idx].data(),
+          gidx_fvalue_map[d_idx].data(), GPUTrainingParam(param),
+          left_child_smallest[d_idx].data(), colsample,
           feature_flags[d_idx].data());
     }
 
@@ -568,7 +568,7 @@ class CQHistMaker: public HistMaker<TStats> {
         const bst_uint ridx = c[j].index;
         const int nid = this->position[ridx];
         if (nid >= 0) {
-          sbuilder[nid].sum_total += gpair[ridx].hess;
+          sbuilder[nid].sum_total += gpair[ridx].GetHess();
         }
       }
       // if only one value, no need to do second pass
@@ -595,7 +595,7 @@ class CQHistMaker: public HistMaker<TStats> {
         for (bst_uint i = 0; i < kBuffer; ++i) {
           bst_uint ridx = c[j + i].index;
           buf_position[i] = this->position[ridx];
-          buf_hess[i] = gpair[ridx].hess;
+          buf_hess[i] = gpair[ridx].GetHess();
         }
         for (bst_uint i = 0; i < kBuffer; ++i) {
           const int nid = buf_position[i];
@@ -608,7 +608,7 @@ class CQHistMaker: public HistMaker<TStats> {
           const bst_uint ridx = c[j].index;
           const int nid = this->position[ridx];
           if (nid >= 0) {
-            sbuilder[nid].Push(c[j].fvalue, gpair[ridx].hess, max_size);
+            sbuilder[nid].Push(c[j].fvalue, gpair[ridx].GetHess(), max_size);
           }
         }
       } else {
@@ -616,7 +616,7 @@ class CQHistMaker: public HistMaker<TStats> {
           const bst_uint ridx = c[j].index;
           const int nid = this->position[ridx];
           if (nid >= 0) {
-            sbuilder[nid].Push(c[j].fvalue, gpair[ridx].hess, max_size);
+            sbuilder[nid].Push(c[j].fvalue, gpair[ridx].GetHess(), max_size);
           }
         }
       }
@@ -818,7 +818,7 @@ class QuantileHistMaker: public HistMaker<TStats> {
         for (size_t i = col_ptr[k]; i < col_ptr[k+1]; ++i) {
           const SparseBatch::Entry &e = col_data[i];
           const int wid = this->node2workindex[e.index];
-          sketchs[wid * tree.param.num_feature + k].Push(e.fvalue, gpair[e.index].hess);
+          sketchs[wid * tree.param.num_feature + k].Push(e.fvalue, gpair[e.index].GetHess());
         }
       }
     }
@@ -98,12 +98,12 @@ class SketchMaker: public BaseMaker {
                        const MetaInfo &info,
                        bst_uint ridx) {
     const bst_gpair &b = gpair[ridx];
-    if (b.grad >= 0.0f) {
-      pos_grad += b.grad;
+    if (b.GetGrad() >= 0.0f) {
+      pos_grad += b.GetGrad();
     } else {
-      neg_grad -= b.grad;
+      neg_grad -= b.GetGrad();
     }
-    sum_hess += b.hess;
+    sum_hess += b.GetHess();
   }
   /*! \brief calculate gain of the solution */
   inline double CalcGain(const TrainParam &param) const {
@@ -199,12 +199,12 @@ class SketchMaker: public BaseMaker {
         const int nid = this->position[ridx];
         if (nid >= 0) {
           const bst_gpair &e = gpair[ridx];
-          if (e.grad >= 0.0f) {
-            sbuilder[3 * nid + 0].sum_total += e.grad;
+          if (e.GetGrad() >= 0.0f) {
+            sbuilder[3 * nid + 0].sum_total += e.GetGrad();
           } else {
-            sbuilder[3 * nid + 1].sum_total -= e.grad;
+            sbuilder[3 * nid + 1].sum_total -= e.GetGrad();
           }
-          sbuilder[3 * nid + 2].sum_total += e.hess;
+          sbuilder[3 * nid + 2].sum_total += e.GetHess();
         }
       }
     } else {
@@ -241,12 +241,12 @@ class SketchMaker: public BaseMaker {
         const int nid = this->position[ridx];
         if (nid >= 0) {
           const bst_gpair &e = gpair[ridx];
-          if (e.grad >= 0.0f) {
-            sbuilder[3 * nid + 0].Push(c[j].fvalue, e.grad, max_size);
+          if (e.GetGrad() >= 0.0f) {
+            sbuilder[3 * nid + 0].Push(c[j].fvalue, e.GetGrad(), max_size);
           } else {
-            sbuilder[3 * nid + 1].Push(c[j].fvalue, -e.grad, max_size);
+            sbuilder[3 * nid + 1].Push(c[j].fvalue, -e.GetGrad(), max_size);
           }
-          sbuilder[3 * nid + 2].Push(c[j].fvalue, e.hess, max_size);
+          sbuilder[3 * nid + 2].Push(c[j].fvalue, e.GetHess(), max_size);
         }
       }
       for (size_t i = 0; i < this->qexpand.size(); ++i) {
@@ -43,10 +43,10 @@ void CheckObjFunction(xgboost::ObjFunction * obj,
 
   ASSERT_EQ(gpair.size(), preds.size());
   for (int i = 0; i < static_cast<int>(gpair.size()); ++i) {
-    EXPECT_NEAR(gpair[i].grad, out_grad[i], 0.01)
+    EXPECT_NEAR(gpair[i].GetGrad(), out_grad[i], 0.01)
       << "Unexpected grad for pred=" << preds[i] << " label=" << labels[i]
       << " weight=" << weights[i];
-    EXPECT_NEAR(gpair[i].hess, out_hess[i], 0.01)
+    EXPECT_NEAR(gpair[i].GetHess(), out_hess[i], 0.01)
       << "Unexpected hess for pred=" << preds[i] << " label=" << labels[i]
       << " weight=" << weights[i];
   }
@@ -16,7 +16,7 @@ TEST(gpu_predictor, Test) {
       std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor"));
 
   std::vector<std::unique_ptr<RegTree>> trees;
-  trees.push_back(std::make_unique<RegTree>());
+  trees.push_back(std::unique_ptr<RegTree>(new RegTree()));
   trees.back()->InitModel();
   (*trees.back())[0].set_leaf(1.5f);
   gbm::GBTreeModel model(0.5);
@@ -14,7 +14,6 @@ UNITTEST_DEPS=lib/libxgboost.a $(DMLC_CORE)/libdmlc.a $(RABIT)/lib/$(LIB_RABIT)
 
 COVER_OBJ=$(patsubst %.o, %.gcda, $(ALL_OBJ)) $(patsubst %.o, %.gcda, $(UNITTEST_OBJ))
 
-# the order of the below targets matter!
 $(UTEST_OBJ_ROOT)/$(GTEST_PATH)/%.o: $(GTEST_PATH)/%.cc
 	@mkdir -p $(@D)
 	$(CXX) $(UNITTEST_CFLAGS) -I$(GTEST_INC) -I$(GTEST_PATH) -o $@ -c $<