Multi-target support for L1 error. (#8652)

- Add matrix support to the median function.
- Iterate through each target for quantile computation.
This commit is contained in:
Jiaming Yuan
2023-01-11 05:51:14 +08:00
committed by GitHub
parent badeff1d74
commit cfa994d57f
19 changed files with 430 additions and 215 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2022 by XGBoost Contributors
/**
* Copyright 2022-2023 by XGBoost Contributors
*/
#include "adaptive.h"
@@ -11,6 +11,7 @@
#include "../common/stats.h"
#include "../common/threading_utils.h"
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "xgboost/linalg.h"
#include "xgboost/tree_model.h"
namespace xgboost {
@@ -66,8 +67,8 @@ void EncodeTreeLeafHost(RegTree const& tree, std::vector<bst_node_t> const& posi
}
void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& position,
MetaInfo const& info, HostDeviceVector<float> const& predt, float alpha,
RegTree* p_tree) {
std::int32_t group_idx, MetaInfo const& info,
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
auto& tree = *p_tree;
std::vector<bst_node_t> nidx;
@@ -88,6 +89,9 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
auto const& h_node_idx = nidx;
auto const& h_node_ptr = nptr;
CHECK_LE(h_node_ptr.back(), info.num_row_);
auto h_predt = linalg::MakeTensorView(predt.ConstHostSpan(),
{info.num_row_, predt.Size() / info.num_row_}, ctx->gpu_id);
// loop over each leaf
common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) {
auto nidx = h_node_idx[k];
@@ -95,14 +99,13 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
CHECK_LT(k + 1, h_node_ptr.size());
size_t n = h_node_ptr[k + 1] - h_node_ptr[k];
auto h_row_set = common::Span<size_t const>{ridx}.subspan(h_node_ptr[k], n);
// multi-target not yet supported.
auto h_labels = info.labels.HostView().Slice(linalg::All(), 0);
auto const& h_predt = predt.ConstHostVector();
CHECK_LE(group_idx, info.labels.Shape(1));
auto h_labels = info.labels.HostView().Slice(linalg::All(), group_idx);
auto h_weights = linalg::MakeVec(&info.weights_);
auto iter = common::MakeIndexTransformIter([&](size_t i) -> float {
auto row_idx = h_row_set[i];
return h_labels(row_idx) - h_predt[row_idx];
return h_labels(row_idx) - h_predt(row_idx, group_idx);
});
auto w_it = common::MakeIndexTransformIter([&](size_t i) -> float {
auto row_idx = h_row_set[i];