From c81238b5c43b57cad037348f858f26d14b7e906e Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 17 Aug 2019 01:05:57 -0400 Subject: [PATCH] Clean up after removing `gpu_exact`. (#4777) * Removed unused functions. * Removed unused parameters. * Move ValueConstraints into constraints.cuh since it's now only used in GPU_Hist. --- src/tree/constraints.cuh | 71 +++++++++++++++++++++++++++++ src/tree/param.h | 79 --------------------------------- src/tree/updater_gpu_common.cuh | 75 +++---------------------------- 3 files changed, 78 insertions(+), 147 deletions(-) diff --git a/src/tree/constraints.cuh b/src/tree/constraints.cuh index e30530c70..0ee901796 100644 --- a/src/tree/constraints.cuh +++ b/src/tree/constraints.cuh @@ -1,5 +1,7 @@ /*! * Copyright 2019 XGBoost contributors + * + * \file Various constraints used in GPU_Hist. */ #ifndef XGBOOST_TREE_CONSTRAINTS_H_ #define XGBOOST_TREE_CONSTRAINTS_H_ @@ -16,6 +18,75 @@ namespace xgboost { +// This class implements monotonic constraints, L1, L2 regularization. 
+struct ValueConstraint { + double lower_bound; + double upper_bound; + XGBOOST_DEVICE ValueConstraint() + : lower_bound(-std::numeric_limits<double>::max()), + upper_bound(std::numeric_limits<double>::max()) {} + inline static void Init(tree::TrainParam *param, unsigned num_feature) { + param->monotone_constraints.resize(num_feature, 0); + } + template <typename ParamT> + XGBOOST_DEVICE inline double CalcWeight(const ParamT &param, tree::GradStats stats) const { + double w = xgboost::tree::CalcWeight(param, stats); + if (w < lower_bound) { + return lower_bound; + } + if (w > upper_bound) { + return upper_bound; + } + return w; + } + + template <typename ParamT> + XGBOOST_DEVICE inline double CalcGain(const ParamT &param, tree::GradStats stats) const { + return tree::CalcGainGivenWeight(param, stats.sum_grad, stats.sum_hess, + CalcWeight(param, stats)); + } + + template <typename ParamT> + XGBOOST_DEVICE inline double CalcSplitGain(const ParamT &param, int constraint, + tree::GradStats left, tree::GradStats right) const { + const double negative_infinity = -std::numeric_limits<double>::infinity(); + double wleft = CalcWeight(param, left); + double wright = CalcWeight(param, right); + double gain = + tree::CalcGainGivenWeight(param, left.sum_grad, left.sum_hess, wleft) + + tree::CalcGainGivenWeight(param, right.sum_grad, right.sum_hess, wright); + if (constraint == 0) { + return gain; + } else if (constraint > 0) { + return wleft <= wright ? gain : negative_infinity; + } else { + return wleft >= wright ? 
gain : negative_infinity; + } + } + + inline void SetChild(const tree::TrainParam &param, bst_uint split_index, + tree::GradStats left, tree::GradStats right, ValueConstraint *cleft, + ValueConstraint *cright) { + int c = param.monotone_constraints.at(split_index); + *cleft = *this; + *cright = *this; + if (c == 0) { + return; + } + double wleft = CalcWeight(param, left); + double wright = CalcWeight(param, right); + double mid = (wleft + wright) / 2; + CHECK(!std::isnan(mid)); + if (c < 0) { + cleft->lower_bound = mid; + cright->upper_bound = mid; + } else { + cleft->upper_bound = mid; + cright->lower_bound = mid; + } + } +}; + // Feature interaction constraints built for GPU Hist updater. struct FeatureInteractionConstraint { protected: diff --git a/src/tree/param.h b/src/tree/param.h index 93f0797c6..96914a0ef 100644 --- a/src/tree/param.h +++ b/src/tree/param.h @@ -62,8 +62,6 @@ struct TrainParam : public dmlc::Parameter<TrainParam> { float sketch_eps; // accuracy of sketch float sketch_ratio; - // leaf vector size - int size_leaf_vector; // option for parallelization int parallel_option; // option to open cacheline optimization @@ -176,10 +174,6 @@ struct TrainParam : public dmlc::Parameter<TrainParam> { .set_lower_bound(0.0f) .set_default(2.0f) .describe("EXP Param: Sketch accuracy related parameter of approximate algorithm."); - DMLC_DECLARE_FIELD(size_leaf_vector) - .set_lower_bound(0) - .set_default(0) - .describe("Size of leaf vectors, reserved for vector trees"); DMLC_DECLARE_FIELD(parallel_option) .set_default(0) .describe("Different types of parallelization algorithm."); @@ -240,10 +234,6 @@ struct TrainParam : public dmlc::Parameter<TrainParam> { inline bool NeedPrune(double loss_chg, int depth) const { return loss_chg < this->min_split_loss; } - /*! \brief whether we can split with current hessian */ - inline bool CannotSplit(double sum_hess, int depth) const { - return sum_hess < this->min_child_weight * 2.0; - } /*! 
\brief maximum sketch size */ inline unsigned MaxSketchSize() const { auto ret = static_cast<unsigned>(sketch_ratio / sketch_eps); @@ -400,75 +390,6 @@ struct GradStats { } }; -// TODO(trivialfis): Remove this class. -struct ValueConstraint { - double lower_bound; - double upper_bound; - XGBOOST_DEVICE ValueConstraint() - : lower_bound(-std::numeric_limits<double>::max()), - upper_bound(std::numeric_limits<double>::max()) {} - inline static void Init(TrainParam *param, unsigned num_feature) { - param->monotone_constraints.resize(num_feature, 0); - } - template <typename ParamT> - XGBOOST_DEVICE inline double CalcWeight(const ParamT &param, GradStats stats) const { - double w = xgboost::tree::CalcWeight(param, stats); - if (w < lower_bound) { - return lower_bound; - } - if (w > upper_bound) { - return upper_bound; - } - return w; - } - - template <typename ParamT> - XGBOOST_DEVICE inline double CalcGain(const ParamT &param, GradStats stats) const { - return CalcGainGivenWeight(param, stats.sum_grad, stats.sum_hess, - CalcWeight(param, stats)); - } - - template <typename ParamT> - XGBOOST_DEVICE inline double CalcSplitGain(const ParamT &param, int constraint, - GradStats left, GradStats right) const { - const double negative_infinity = -std::numeric_limits<double>::infinity(); - double wleft = CalcWeight(param, left); - double wright = CalcWeight(param, right); - double gain = - CalcGainGivenWeight(param, left.sum_grad, left.sum_hess, wleft) + - CalcGainGivenWeight(param, right.sum_grad, right.sum_hess, wright); - if (constraint == 0) { - return gain; - } else if (constraint > 0) { - return wleft <= wright ? gain : negative_infinity; - } else { - return wleft >= wright ? 
gain : negative_infinity; - } - } - - inline void SetChild(const TrainParam &param, bst_uint split_index, - GradStats left, GradStats right, ValueConstraint *cleft, - ValueConstraint *cright) { - int c = param.monotone_constraints.at(split_index); - *cleft = *this; - *cright = *this; - if (c == 0) { - return; - } - double wleft = CalcWeight(param, left); - double wright = CalcWeight(param, right); - double mid = (wleft + wright) / 2; - CHECK(!std::isnan(mid)); - if (c < 0) { - cleft->lower_bound = mid; - cright->upper_bound = mid; - } else { - cleft->upper_bound = mid; - cright->lower_bound = mid; - } - } -}; - /*! * \brief statistics that is helpful to store * and represent a split solution for the tree diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index f177a1162..d32901311 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -69,7 +69,7 @@ struct GPUTrainingParam { max_delta_step(param.max_delta_step) {} }; -using NodeIdT = int; +using NodeIdT = int32_t; /** used to assign default id to a Node */ static const int kUnusedNode = -1; @@ -88,8 +88,9 @@ enum DefaultDirection { struct DeviceSplitCandidate { float loss_chg; DefaultDirection dir; - float fvalue; int findex; + float fvalue; + GradientPair left_sum; GradientPair right_sum; @@ -107,10 +108,10 @@ struct DeviceSplitCandidate { } XGBOOST_DEVICE void Update(float loss_chg_in, DefaultDirection dir_in, - float fvalue_in, int findex_in, - GradientPair left_sum_in, - GradientPair right_sum_in, - const GPUTrainingParam& param) { + float fvalue_in, int findex_in, + GradientPair left_sum_in, + GradientPair right_sum_in, + const GPUTrainingParam& param) { if (loss_chg_in > loss_chg && left_sum_in.GetHess() >= param.min_child_weight && right_sum_in.GetHess() >= param.min_child_weight) { @@ -214,76 +215,14 @@ struct SumCallbackOp { } }; -template <typename GradientPairT> -XGBOOST_DEVICE inline float DeviceCalcLossChange(const GPUTrainingParam& param, - const GradientPairT& left, - 
const GradientPairT& parent_sum, - const float& parent_gain) { - GradientPairT right = parent_sum - left; - float left_gain = CalcGain(param, left.GetGrad(), left.GetHess()); - float right_gain = CalcGain(param, right.GetGrad(), right.GetHess()); - return left_gain + right_gain - parent_gain; -} - // Total number of nodes in tree, given depth XGBOOST_DEVICE inline int MaxNodesDepth(int depth) { return (1 << (depth + 1)) - 1; } -// Number of nodes at this level of the tree -XGBOOST_DEVICE inline int MaxNodesLevel(int depth) { return 1 << depth; } - -// Whether a node is currently being processed at current depth -XGBOOST_DEVICE inline bool IsNodeActive(int nidx, int depth) { - return nidx >= MaxNodesDepth(depth - 1); -} - -XGBOOST_DEVICE inline int ParentNodeIdx(int nidx) { return (nidx - 1) / 2; } - -XGBOOST_DEVICE inline int LeftChildNodeIdx(int nidx) { - return nidx * 2 + 1; -} - -XGBOOST_DEVICE inline int RightChildNodeIdx(int nidx) { - return nidx * 2 + 2; -} - -XGBOOST_DEVICE inline bool IsLeftChild(int nidx) { - return nidx % 2 == 1; -} - -// Copy gpu dense representation of tree to xgboost sparse representation -inline void Dense2SparseTree(RegTree* p_tree, - common::Span<DeviceNodeStats> nodes, - const TrainParam& param) { - RegTree& tree = *p_tree; - std::vector<DeviceNodeStats> h_nodes(nodes.size()); - dh::safe_cuda(cudaMemcpy(h_nodes.data(), nodes.data(), - nodes.size() * sizeof(DeviceNodeStats), - cudaMemcpyDeviceToHost)); - - int nid = 0; - for (int gpu_nid = 0; gpu_nid < h_nodes.size(); gpu_nid++) { - const DeviceNodeStats& n = h_nodes[gpu_nid]; - if (!n.IsUnused() && !n.IsLeaf()) { - tree.ExpandNode(nid, n.fidx, n.fvalue, n.dir == kLeftDir, n.weight, 0.0f, - 0.0f, n.root_gain, n.sum_gradients.GetHess()); - tree.Stat(nid).loss_chg = n.root_gain; - tree.Stat(nid).base_weight = n.weight; - tree.Stat(nid).sum_hess = n.sum_gradients.GetHess(); - 
nid++; - } - } -} - /* * Random */ - struct BernoulliRng { float p; uint32_t seed;