Extract CPUExpandEntry and HistParam. (#7321)
* Remove kRootNid. * Check for empty hessian.
This commit is contained in:
parent
6cdcfe8128
commit
8e619010d0
64
src/tree/hist/expand_entry.h
Normal file
64
src/tree/hist/expand_entry.h
Normal file
@ -0,0 +1,64 @@
|
||||
/*!
|
||||
* Copyright 2021 XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
|
||||
#define XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
|
||||
|
||||
#include <utility>
|
||||
#include "../param.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
|
||||
struct CPUExpandEntry {
|
||||
int nid;
|
||||
int depth;
|
||||
SplitEntry split;
|
||||
CPUExpandEntry() = default;
|
||||
XGBOOST_DEVICE
|
||||
CPUExpandEntry(int nid, int depth, SplitEntry split)
|
||||
: nid(nid), depth(depth), split(std::move(split)) {}
|
||||
CPUExpandEntry(int nid, int depth, float loss_chg)
|
||||
: nid(nid), depth(depth) {
|
||||
split.loss_chg = loss_chg;
|
||||
}
|
||||
|
||||
bool IsValid(const TrainParam& param, int num_leaves) const {
|
||||
if (split.loss_chg <= kRtEps) return false;
|
||||
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
|
||||
return false;
|
||||
}
|
||||
if (split.loss_chg < param.min_split_loss) {
|
||||
return false;
|
||||
}
|
||||
if (param.max_depth > 0 && depth == param.max_depth) {
|
||||
return false;
|
||||
}
|
||||
if (param.max_leaves > 0 && num_leaves == param.max_leaves) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
float GetLossChange() const { return split.loss_chg; }
|
||||
bst_node_t GetNodeId() const { return nid; }
|
||||
|
||||
static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) {
|
||||
if (param.max_depth > 0 && depth >= param.max_depth) return false;
|
||||
if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const CPUExpandEntry& e) {
|
||||
os << "ExpandEntry: \n";
|
||||
os << "nidx: " << e.nid << "\n";
|
||||
os << "depth: " << e.depth << "\n";
|
||||
os << "loss: " << e.split.loss_chg << "\n";
|
||||
os << "left_sum: " << e.split.left_sum << "\n";
|
||||
os << "right_sum: " << e.split.right_sum << "\n";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
#endif // XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
|
||||
23
src/tree/hist/param.h
Normal file
23
src/tree/hist/param.h
Normal file
@ -0,0 +1,23 @@
|
||||
/*!
|
||||
* Copyright 2021 XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_TREE_HIST_PARAM_H_
|
||||
#define XGBOOST_TREE_HIST_PARAM_H_
|
||||
#include "xgboost/parameter.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace tree {
|
||||
// training parameters specific to this algorithm
|
||||
struct CPUHistMakerTrainParam
|
||||
: public XGBoostParameter<CPUHistMakerTrainParam> {
|
||||
bool single_precision_histogram;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(CPUHistMakerTrainParam) {
|
||||
DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
|
||||
"Use single precision to build histograms.");
|
||||
}
|
||||
};
|
||||
} // namespace tree
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_TREE_HIST_PARAM_H_
|
||||
@ -124,7 +124,7 @@ template <bool any_missing>
|
||||
void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
|
||||
DMatrix *p_fmat, RegTree *p_tree, const std::vector<GradientPair> &gpair_h,
|
||||
int *num_leaves, std::vector<CPUExpandEntry> *expand) {
|
||||
CPUExpandEntry node(CPUExpandEntry::kRootNid, p_tree->GetDepth(0), 0.0f);
|
||||
CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0), 0.0f);
|
||||
|
||||
nodes_for_explicit_hist_build_.clear();
|
||||
nodes_for_subtraction_trick_.clear();
|
||||
@ -135,7 +135,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
|
||||
nodes_for_subtraction_trick_, gpair_h);
|
||||
|
||||
{
|
||||
auto nid = CPUExpandEntry::kRootNid;
|
||||
auto nid = RegTree::kRoot;
|
||||
GHistRowT hist = this->histogram_builder_->Histogram()[nid];
|
||||
GradientPairT grad_stat;
|
||||
if (data_layout_ == DataLayout::kDenseDataZeroBased ||
|
||||
|
||||
@ -23,6 +23,9 @@
|
||||
|
||||
#include "hist/evaluate_splits.h"
|
||||
#include "hist/histogram.h"
|
||||
#include "hist/expand_entry.h"
|
||||
#include "hist/param.h"
|
||||
|
||||
#include "constraints.h"
|
||||
#include "./param.h"
|
||||
#include "./driver.h"
|
||||
@ -89,51 +92,6 @@ using xgboost::common::GHistBuilder;
|
||||
using xgboost::common::ColumnMatrix;
|
||||
using xgboost::common::Column;
|
||||
|
||||
// training parameters specific to this algorithm
|
||||
struct CPUHistMakerTrainParam
|
||||
: public XGBoostParameter<CPUHistMakerTrainParam> {
|
||||
bool single_precision_histogram = false;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(CPUHistMakerTrainParam) {
|
||||
DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
|
||||
"Use single precision to build histograms.");
|
||||
}
|
||||
};
|
||||
|
||||
/* tree growing policies */
|
||||
struct CPUExpandEntry {
|
||||
static const int kRootNid = 0;
|
||||
static const int kEmptyNid = -1;
|
||||
int nid;
|
||||
int depth;
|
||||
SplitEntry split;
|
||||
|
||||
CPUExpandEntry() = default;
|
||||
CPUExpandEntry(int nid, int depth, bst_float loss_chg)
|
||||
: nid(nid), depth(depth) {
|
||||
split.loss_chg = loss_chg;
|
||||
}
|
||||
|
||||
bool IsValid(TrainParam const ¶m, int32_t num_leaves) const {
|
||||
bool invalid = split.loss_chg <= kRtEps ||
|
||||
(param.max_depth > 0 && this->depth == param.max_depth) ||
|
||||
(param.max_leaves > 0 && num_leaves == param.max_leaves);
|
||||
return !invalid;
|
||||
}
|
||||
|
||||
bst_float GetLossChange() const {
|
||||
return split.loss_chg;
|
||||
}
|
||||
|
||||
int GetNodeId() const {
|
||||
return nid;
|
||||
}
|
||||
|
||||
int GetDepth() const {
|
||||
return depth;
|
||||
}
|
||||
};
|
||||
|
||||
/*! \brief construct a tree using quantized feature values */
|
||||
class QuantileHistMaker: public TreeUpdater {
|
||||
public:
|
||||
|
||||
@ -258,7 +258,7 @@ void TestBuildHistogram(bool is_distributed) {
|
||||
std::iota(row_indices.begin(), row_indices.end(), 0);
|
||||
row_set_collection_.Init();
|
||||
|
||||
CPUExpandEntry node(CPUExpandEntry::kRootNid, tree.GetDepth(0), 0.0f);
|
||||
CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f);
|
||||
std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
|
||||
nodes_for_explicit_hist_build_.push_back(node);
|
||||
histogram.BuildHist(p_fmat.get(), &tree, row_set_collection_,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user