Extract CPUExpandEntry and HistParam. (#7321)
* Remove kRootNid.
* Check for empty hessian.
This commit is contained in:
parent 6cdcfe8128
commit 8e619010d0
src/tree/hist/expand_entry.h (new file, 64 lines)
@@ -0,0 +1,64 @@
/*!
 * Copyright 2021 XGBoost contributors
 */
#ifndef XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
#define XGBOOST_TREE_HIST_EXPAND_ENTRY_H_

#include <utility>
#include "../param.h"

namespace xgboost {
namespace tree {

struct CPUExpandEntry {
  int nid;
  int depth;
  SplitEntry split;
  CPUExpandEntry() = default;
  XGBOOST_DEVICE
  CPUExpandEntry(int nid, int depth, SplitEntry split)
      : nid(nid), depth(depth), split(std::move(split)) {}
  CPUExpandEntry(int nid, int depth, float loss_chg)
      : nid(nid), depth(depth) {
    split.loss_chg = loss_chg;
  }

  bool IsValid(const TrainParam& param, int num_leaves) const {
    if (split.loss_chg <= kRtEps) return false;
    if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
      return false;
    }
    if (split.loss_chg < param.min_split_loss) {
      return false;
    }
    if (param.max_depth > 0 && depth == param.max_depth) {
      return false;
    }
    if (param.max_leaves > 0 && num_leaves == param.max_leaves) {
      return false;
    }
    return true;
  }

  float GetLossChange() const { return split.loss_chg; }
  bst_node_t GetNodeId() const { return nid; }

  static bool ChildIsValid(const TrainParam& param, int depth, int num_leaves) {
    if (param.max_depth > 0 && depth >= param.max_depth) return false;
    if (param.max_leaves > 0 && num_leaves >= param.max_leaves) return false;
    return true;
  }

  friend std::ostream& operator<<(std::ostream& os, const CPUExpandEntry& e) {
    os << "ExpandEntry: \n";
    os << "nidx: " << e.nid << "\n";
    os << "depth: " << e.depth << "\n";
    os << "loss: " << e.split.loss_chg << "\n";
    os << "left_sum: " << e.split.left_sum << "\n";
    os << "right_sum: " << e.split.right_sum << "\n";
    return os;
  }
};
}  // namespace tree
}  // namespace xgboost
#endif  // XGBOOST_TREE_HIST_EXPAND_ENTRY_H_
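Illustration (not part of the diff): a minimal sketch of the empty-hessian check called out in the commit message. It assumes the upstream XGBoost headers are available; the include paths, parameter values, and the direct assignment of TrainParam members are assumptions for the sketch, not part of this change.

// Sketch: a candidate with a positive loss_chg but default (zero) hessian sums
// on both children is rejected by the new CPUExpandEntry::IsValid.
#include "tree/hist/expand_entry.h"  // assumed include path (header added above)
#include "xgboost/tree_model.h"      // assumed path, for RegTree::kRoot

int main() {
  xgboost::tree::TrainParam param;
  param.max_depth = 6;          // fields read by IsValid; values are arbitrary
  param.min_split_loss = 0.0f;
  param.max_leaves = 0;

  xgboost::tree::CPUExpandEntry entry(xgboost::RegTree::kRoot, /*depth=*/0, /*loss_chg=*/0.5f);
  // loss_chg is above kRtEps, but split.left_sum/right_sum were never filled,
  // so GetHess() == 0 on both sides and the entry is reported as invalid.
  bool valid = entry.IsValid(param, /*num_leaves=*/1);
  return valid ? 1 : 0;  // expected: 0 (not valid)
}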
src/tree/hist/param.h (new file, 23 lines)
@@ -0,0 +1,23 @@
/*!
 * Copyright 2021 XGBoost contributors
 */
#ifndef XGBOOST_TREE_HIST_PARAM_H_
#define XGBOOST_TREE_HIST_PARAM_H_
#include "xgboost/parameter.h"

namespace xgboost {
namespace tree {
// training parameters specific to this algorithm
struct CPUHistMakerTrainParam
    : public XGBoostParameter<CPUHistMakerTrainParam> {
  bool single_precision_histogram;
  // declare parameters
  DMLC_DECLARE_PARAMETER(CPUHistMakerTrainParam) {
    DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
        "Use single precision to build histograms.");
  }
};
}  // namespace tree
}  // namespace xgboost

#endif  // XGBOOST_TREE_HIST_PARAM_H_
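Illustration (not part of the diff): a minimal sketch of driving the extracted CPUHistMakerTrainParam through the DMLC parameter machinery. UpdateAllowUnknown and xgboost::Args are assumed to behave as in upstream XGBoost, and the DMLC_REGISTER_PARAMETER line exists only so this standalone sketch links; upstream registers the parameter in the updater's translation unit instead.

#include "tree/hist/param.h"  // assumed include path (header added above)
#include "xgboost/base.h"     // xgboost::Args

namespace xgboost {
namespace tree {
// Registration normally lives in the updater's .cc file; shown here only to
// make the sketch self-contained. Do not register the same parameter twice
// in one binary.
DMLC_REGISTER_PARAMETER(CPUHistMakerTrainParam);
}  // namespace tree
}  // namespace xgboost

int main() {
  xgboost::tree::CPUHistMakerTrainParam hist_param;
  hist_param.UpdateAllowUnknown(
      xgboost::Args{{"single_precision_histogram", "true"}});
  return hist_param.single_precision_histogram ? 0 : 1;  // expected: 0
}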
@@ -124,7 +124,7 @@ template <bool any_missing>
 void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
     DMatrix *p_fmat, RegTree *p_tree, const std::vector<GradientPair> &gpair_h,
     int *num_leaves, std::vector<CPUExpandEntry> *expand) {
-  CPUExpandEntry node(CPUExpandEntry::kRootNid, p_tree->GetDepth(0), 0.0f);
+  CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0), 0.0f);
 
   nodes_for_explicit_hist_build_.clear();
   nodes_for_subtraction_trick_.clear();
@@ -135,7 +135,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
                                        nodes_for_subtraction_trick_, gpair_h);
 
   {
-    auto nid = CPUExpandEntry::kRootNid;
+    auto nid = RegTree::kRoot;
     GHistRowT hist = this->histogram_builder_->Histogram()[nid];
     GradientPairT grad_stat;
     if (data_layout_ == DataLayout::kDenseDataZeroBased ||
@@ -23,6 +23,9 @@
 #include "hist/evaluate_splits.h"
 #include "hist/histogram.h"
+#include "hist/expand_entry.h"
+#include "hist/param.h"
 
 #include "constraints.h"
 #include "./param.h"
 #include "./driver.h"
@@ -89,51 +92,6 @@ using xgboost::common::GHistBuilder;
 using xgboost::common::ColumnMatrix;
 using xgboost::common::Column;
 
-// training parameters specific to this algorithm
-struct CPUHistMakerTrainParam
-    : public XGBoostParameter<CPUHistMakerTrainParam> {
-  bool single_precision_histogram = false;
-  // declare parameters
-  DMLC_DECLARE_PARAMETER(CPUHistMakerTrainParam) {
-    DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
-        "Use single precision to build histograms.");
-  }
-};
-
-/* tree growing policies */
-struct CPUExpandEntry {
-  static const int kRootNid = 0;
-  static const int kEmptyNid = -1;
-  int nid;
-  int depth;
-  SplitEntry split;
-
-  CPUExpandEntry() = default;
-  CPUExpandEntry(int nid, int depth, bst_float loss_chg)
-      : nid(nid), depth(depth) {
-    split.loss_chg = loss_chg;
-  }
-
-  bool IsValid(TrainParam const &param, int32_t num_leaves) const {
-    bool invalid = split.loss_chg <= kRtEps ||
-                   (param.max_depth > 0 && this->depth == param.max_depth) ||
-                   (param.max_leaves > 0 && num_leaves == param.max_leaves);
-    return !invalid;
-  }
-
-  bst_float GetLossChange() const {
-    return split.loss_chg;
-  }
-
-  int GetNodeId() const {
-    return nid;
-  }
-
-  int GetDepth() const {
-    return depth;
-  }
-};
-
 /*! \brief construct a tree using quantized feature values */
 class QuantileHistMaker: public TreeUpdater {
  public:
@@ -258,7 +258,7 @@ void TestBuildHistogram(bool is_distributed) {
   std::iota(row_indices.begin(), row_indices.end(), 0);
   row_set_collection_.Init();
 
-  CPUExpandEntry node(CPUExpandEntry::kRootNid, tree.GetDepth(0), 0.0f);
+  CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f);
   std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
   nodes_for_explicit_hist_build_.push_back(node);
   histogram.BuildHist(p_fmat.get(), &tree, row_set_collection_,