Clean up training code. (#3825)
* Remove GHistRow, GHistEntry, GHistIndexRow. * Remove kSimpleStats. * Remove CheckInfo, SetLeafVec in GradStats and in SKStats. * Clean up the GradStats. * Cleanup calcgain. * Move LossChangeMissing out of common. * Remove [] operator from GHistIndexBlock.
This commit is contained in:
@@ -50,6 +50,28 @@ struct GPUHistMakerTrainParam
|
||||
|
||||
DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
|
||||
|
||||
// With constraints
|
||||
template <typename GradientPairT>
|
||||
XGBOOST_DEVICE float inline LossChangeMissing(
|
||||
const GradientPairT& scan, const GradientPairT& missing, const GradientPairT& parent_sum,
|
||||
const float& parent_gain, const GPUTrainingParam& param, int constraint,
|
||||
const ValueConstraint& value_constraint,
|
||||
bool& missing_left_out) { // NOLINT
|
||||
float missing_left_gain = value_constraint.CalcSplitGain(
|
||||
param, constraint, GradStats(scan + missing),
|
||||
GradStats(parent_sum - (scan + missing)));
|
||||
float missing_right_gain = value_constraint.CalcSplitGain(
|
||||
param, constraint, GradStats(scan), GradStats(parent_sum - scan));
|
||||
|
||||
if (missing_left_gain >= missing_right_gain) {
|
||||
missing_left_out = true;
|
||||
return missing_left_gain - parent_gain;
|
||||
} else {
|
||||
missing_left_out = false;
|
||||
return missing_right_gain - parent_gain;
|
||||
}
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief
|
||||
*
|
||||
@@ -942,7 +964,6 @@ class GPUHistMakerSpecialised{
|
||||
void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* dmat,
|
||||
const std::vector<RegTree*>& trees) {
|
||||
monitor_.Start("Update", dist_.Devices());
|
||||
GradStats::CheckInfo(dmat->Info());
|
||||
// rescale learning rate according to size of trees
|
||||
float lr = param_.learning_rate;
|
||||
param_.learning_rate = lr / trees.size();
|
||||
@@ -1183,11 +1204,12 @@ class GPUHistMakerSpecialised{
|
||||
|
||||
void ApplySplit(const ExpandEntry& candidate, RegTree* p_tree) {
|
||||
RegTree& tree = *p_tree;
|
||||
GradStats left_stats(param_);
|
||||
|
||||
GradStats left_stats;
|
||||
left_stats.Add(candidate.split.left_sum);
|
||||
GradStats right_stats(param_);
|
||||
GradStats right_stats;
|
||||
right_stats.Add(candidate.split.right_sum);
|
||||
GradStats parent_sum(param_);
|
||||
GradStats parent_sum;
|
||||
parent_sum.Add(left_stats);
|
||||
parent_sum.Add(right_stats);
|
||||
node_value_constraints_.resize(tree.GetNodes().size());
|
||||
|
||||
Reference in New Issue
Block a user