Clean up training code. (#3825)

* Remove GHistRow, GHistEntry, GHistIndexRow.
* Remove kSimpleStats.
* Remove CheckInfo, SetLeafVec in GradStats and in SKStats.
* Clean up the GradStats.
* Clean up calcgain.
* Move LossChangeMissing out of common.
* Remove [] operator from GHistIndexBlock.
This commit is contained in:
Jiaming Yuan
2019-02-07 14:22:13 +08:00
committed by GitHub
parent 325b16bccd
commit 017c97b8ce
19 changed files with 306 additions and 406 deletions

View File

@@ -50,6 +50,28 @@ struct GPUHistMakerTrainParam
DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
// With constraints
/*!
 * \brief Compare the constrained split gain obtained with the `missing`
 *        gradient sum added to the left side against the gain obtained with
 *        it left on the right side, and report the better of the two.
 *
 * \param scan             Gradient sum used as the left-hand side of the split
 *                         (presumably the running prefix sum over bins; confirm
 *                         against the caller).
 * \param missing          Gradient sum that is tried on either side.
 * \param parent_sum       Gradient sum from which each right-hand side is
 *                         derived by subtraction.
 * \param parent_gain      Value subtracted from the selected split gain.
 * \param param            Training parameters forwarded to CalcSplitGain.
 * \param constraint       Constraint value forwarded to CalcSplitGain.
 * \param value_constraint Object whose CalcSplitGain evaluates each candidate.
 * \param missing_left_out Output: set to true when placing `missing` on the
 *                         left yields a gain greater than or equal to the
 *                         alternative, false otherwise.
 * \return The selected split gain minus parent_gain.
 */
template <typename GradientPairT>
XGBOOST_DEVICE float inline LossChangeMissing(
    const GradientPairT& scan, const GradientPairT& missing, const GradientPairT& parent_sum,
    const float& parent_gain, const GPUTrainingParam& param, int constraint,
    const ValueConstraint& value_constraint,
    bool& missing_left_out) {  // NOLINT
  // Candidate A: `missing` joins the left child.
  const auto left_with_missing = scan + missing;
  const auto right_without_missing = parent_sum - left_with_missing;
  const float gain_missing_left = value_constraint.CalcSplitGain(
      param, constraint, GradStats(left_with_missing),
      GradStats(right_without_missing));

  // Candidate B: `missing` stays in the right child.
  const auto right_with_missing = parent_sum - scan;
  const float gain_missing_right = value_constraint.CalcSplitGain(
      param, constraint, GradStats(scan), GradStats(right_with_missing));

  // Ties go to the left; a failed comparison (e.g. NaN) falls to the right.
  missing_left_out = gain_missing_left >= gain_missing_right;
  const float chosen_gain =
      missing_left_out ? gain_missing_left : gain_missing_right;
  return chosen_gain - parent_gain;
}
/*!
* \brief
*
@@ -942,7 +964,6 @@ class GPUHistMakerSpecialised{
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
const std::vector<RegTree*>& trees) {
monitor_.Start("Update", dist_.Devices());
GradStats::CheckInfo(dmat->Info());
// rescale learning rate according to size of trees
float lr = param_.learning_rate;
param_.learning_rate = lr / trees.size();
@@ -1183,11 +1204,12 @@ class GPUHistMakerSpecialised{
void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
RegTree& tree = *p_tree;
GradStats left_stats(param_);
GradStats left_stats;
left_stats.Add(candidate.split.left_sum);
GradStats right_stats(param_);
GradStats right_stats;
right_stats.Add(candidate.split.right_sum);
GradStats parent_sum(param_);
GradStats parent_sum;
parent_sum.Add(left_stats);
parent_sum.Add(right_stats);
node_value_constraints_.resize(tree.GetNodes().size());