Fix gain calculation in multi-target tree. (#9978)
This commit is contained in:
@@ -730,6 +730,9 @@ class HistMultiEvaluator {
|
||||
|
||||
std::size_t n_nodes = p_tree->Size();
|
||||
gain_.resize(n_nodes);
|
||||
// Re-calculate weight without learning rate.
|
||||
CalcWeight(*param_, left_sum, left_weight);
|
||||
CalcWeight(*param_, right_sum, right_weight);
|
||||
gain_[left_child] = CalcGainGivenWeight(*param_, left_sum, left_weight);
|
||||
gain_[right_child] = CalcGainGivenWeight(*param_, right_sum, right_weight);
|
||||
|
||||
|
||||
@@ -195,8 +195,9 @@ void MultiTargetTree::Expand(bst_node_t nidx, bst_feature_t split_idx, float spl
|
||||
split_index_.resize(n);
|
||||
split_index_[nidx] = split_idx;
|
||||
|
||||
split_conds_.resize(n);
|
||||
split_conds_.resize(n, std::numeric_limits<float>::quiet_NaN());
|
||||
split_conds_[nidx] = split_cond;
|
||||
|
||||
default_left_.resize(n);
|
||||
default_left_[nidx] = static_cast<std::uint8_t>(default_left);
|
||||
|
||||
|
||||
@@ -149,6 +149,9 @@ class MultiTargetHistBuilder {
|
||||
}
|
||||
|
||||
void InitData(DMatrix *p_fmat, RegTree const *p_tree) {
|
||||
if (collective::IsDistributed()) {
|
||||
LOG(FATAL) << "Distributed training for vector-leaf is not yet supported.";
|
||||
}
|
||||
monitor_->Start(__func__);
|
||||
|
||||
p_last_fmat_ = p_fmat;
|
||||
|
||||
Reference in New Issue
Block a user