xgboost/src/objective/adaptive.h
Jiaming Yuan fffb1fca52
Calculate base_score based on input labels for mae. (#8107)
Fit an intercept as base score for abs loss.
2022-09-20 20:53:54 +08:00

85 lines
2.9 KiB
C++

/*!
* Copyright 2022 by XGBoost Contributors
*/
#pragma once
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

#include "../common/common.h"
#include "rabit/rabit.h"
#include "xgboost/generic_parameters.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/tree_model.h"
namespace xgboost {
namespace obj {
namespace detail {
inline void FillMissingLeaf(std::vector<bst_node_t> const& maybe_missing,
std::vector<bst_node_t>* p_nidx, std::vector<size_t>* p_nptr) {
auto& h_node_idx = *p_nidx;
auto& h_node_ptr = *p_nptr;
for (auto leaf : maybe_missing) {
if (std::binary_search(h_node_idx.cbegin(), h_node_idx.cend(), leaf)) {
continue;
}
auto it = std::upper_bound(h_node_idx.cbegin(), h_node_idx.cend(), leaf);
auto pos = it - h_node_idx.cbegin();
h_node_idx.insert(h_node_idx.cbegin() + pos, leaf);
h_node_ptr.insert(h_node_ptr.cbegin() + pos, h_node_ptr[pos]);
}
}
/*!
 * \brief Combine the leaf quantiles of all workers and store them as leaf values.
 *
 * Each worker passes the quantiles it computed, with NaN marking a leaf for which it
 * has no value. The values are summed across workers through rabit all-reduce calls and
 * divided by the number of workers that supplied a non-NaN value, so the new leaf value
 * is the mean of the valid quantiles. A leaf without any valid quantile keeps the value
 * currently stored in the tree.
 *
 * \param p_quantiles In/out. Must be empty or hold exactly as many entries as the largest
 *                    leaf count among the workers (checked). An empty vector is treated
 *                    as "no valid quantile for any leaf". On return it holds the
 *                    combined values.
 * \param nidx        Node indices of the leaves known to this worker; entry i belongs to
 *                    quantile i. Every node must be a leaf (checked).
 * \param p_tree      Tree whose leaf values are overwritten.
 */
inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_node_t> const& nidx,
                             RegTree* p_tree) {
  auto& tree = *p_tree;
  auto& quantiles = *p_quantiles;
  size_t const n_local_leaf{nidx.size()};

  // All workers must reduce buffers of the same length, so agree on the largest count.
  size_t n_leaf{n_local_leaf};
  rabit::Allreduce<rabit::op::Max>(&n_leaf, 1);
  CHECK(quantiles.empty() || quantiles.size() == n_leaf);
  if (quantiles.empty()) {
    quantiles.resize(n_leaf, std::numeric_limits<float>::quiet_NaN());
  }

  // Number of workers that have a valid (non-NaN) quantile for each leaf.
  std::vector<int32_t> n_valids(quantiles.size());
  std::transform(quantiles.cbegin(), quantiles.cend(), n_valids.begin(),
                 [](float q) { return static_cast<int32_t>(!std::isnan(q)); });
  rabit::Allreduce<rabit::op::Sum>(n_valids.data(), n_valids.size());

  // NaN would propagate through the sum, so it is replaced by 0 before reducing.
  std::replace_if(
      quantiles.begin(), quantiles.end(), [](float q) { return std::isnan(q); }, 0.f);
  // Sum here, divide by the number of contributors below: use the mean value.
  rabit::Allreduce<rabit::op::Sum>(quantiles.data(), quantiles.size());

  for (size_t i = 0; i < n_leaf; ++i) {
    if (n_valids[i] > 0) {
      quantiles[i] /= static_cast<float>(n_valids[i]);
    } else if (i < n_local_leaf) {
      // Use original leaf value if no worker can provide the quantile.  The bound on i
      // matters because n_leaf is the maximum over all workers and may exceed the
      // number of node indices this worker holds; such entries are never written to
      // the tree below, so they are simply left untouched.
      quantiles[i] = tree[nidx[i]].LeafValue();
    }
  }

  for (size_t i = 0; i < n_local_leaf; ++i) {
    auto const leaf_idx = nidx[i];
    auto const value = quantiles[i];
    CHECK(tree[leaf_idx].IsLeaf());
    tree[leaf_idx].SetLeaf(value);
  }
}
/*!
 * \brief Declaration only; the definition is not part of this header.
 *
 * NOTE(review): judging by the name this is the device (GPU) variant of the leaf update
 * that UpdateTreeLeafHost performs on the host -- confirm against the definition.
 *
 * \param ctx      Context object passed through to the implementation.
 * \param position Read-only span of node indices; presumably the node each row ended up
 *                 in -- TODO confirm.
 * \param info     Meta info of the data set.
 * \param predt    Read-only prediction buffer.
 * \param alpha    NOTE(review): presumably the quantile level used for the new leaf
 *                 values -- confirm.
 * \param p_tree   Tree passed by non-const pointer; presumably its leaf values are
 *                 updated in place -- TODO confirm.
 */
void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
MetaInfo const& info, HostDeviceVector<float> const& predt, float alpha,
RegTree* p_tree);
/*!
 * \brief Declaration only; the definition is not part of this header.
 *
 * NOTE(review): judging by the name this is the host (CPU) variant of the leaf update;
 * unlike UpdateTreeLeafDevice it takes the positions as a std::vector -- confirm the
 * exact behaviour against the definition.
 *
 * \param ctx      Context object passed through to the implementation.
 * \param position Read-only vector of node indices; presumably the node each row ended
 *                 up in -- TODO confirm.
 * \param info     Meta info of the data set.
 * \param predt    Read-only prediction buffer.
 * \param alpha    NOTE(review): presumably the quantile level used for the new leaf
 *                 values -- confirm.
 * \param p_tree   Tree passed by non-const pointer; presumably its leaf values are
 *                 updated in place -- TODO confirm.
 */
void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& position,
MetaInfo const& info, HostDeviceVector<float> const& predt, float alpha,
RegTree* p_tree);
} // namespace detail
} // namespace obj
} // namespace xgboost