Initial support for quantile loss. (#8750)

- Add support for Python.
- Add objective.
This commit is contained in:
Jiaming Yuan
2023-02-16 02:30:18 +08:00
committed by GitHub
parent 282b1729da
commit cce4af4acf
26 changed files with 701 additions and 70 deletions

View File

@@ -3,17 +3,25 @@
*/
#include "adaptive.h"
#include <limits>
#include <vector>
#include <algorithm> // std::transform,std::find_if,std::copy,std::unique
#include <cmath> // std::isnan
#include <cstddef> // std::size_t
#include <iterator> // std::distance
#include <vector> // std::vector
#include "../common/algorithm.h" // ArgSort
#include "../common/common.h" // AssertGPUSupport
#include "../common/numeric.h" // RunLengthEncode
#include "../common/stats.h" // Quantile,WeightedQuantile
#include "../common/threading_utils.h" // ParallelFor
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "xgboost/base.h" // bst_node_t
#include "xgboost/context.h" // Context
#include "xgboost/linalg.h"
#include "xgboost/tree_model.h"
#include "xgboost/data.h" // MetaInfo
#include "xgboost/host_device_vector.h" // HostDeviceVector
#include "xgboost/linalg.h" // MakeTensorView
#include "xgboost/span.h" // Span
#include "xgboost/tree_model.h" // RegTree
namespace xgboost {
namespace obj {
@@ -100,8 +108,8 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
CHECK_LT(k + 1, h_node_ptr.size());
size_t n = h_node_ptr[k + 1] - h_node_ptr[k];
auto h_row_set = common::Span<size_t const>{ridx}.subspan(h_node_ptr[k], n);
CHECK_LE(group_idx, info.labels.Shape(1));
auto h_labels = info.labels.HostView().Slice(linalg::All(), group_idx);
auto h_labels = info.labels.HostView().Slice(linalg::All(), IdxY(info, group_idx));
auto h_weights = linalg::MakeVec(&info.weights_);
auto iter = common::MakeIndexTransformIter([&](size_t i) -> float {
@@ -115,9 +123,9 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
float q{0};
if (info.weights_.Empty()) {
q = common::Quantile(alpha, iter, iter + h_row_set.size());
q = common::Quantile(ctx, alpha, iter, iter + h_row_set.size());
} else {
q = common::WeightedQuantile(alpha, iter, iter + h_row_set.size(), w_it);
q = common::WeightedQuantile(ctx, alpha, iter, iter + h_row_set.size(), w_it);
}
if (std::isnan(q)) {
CHECK(h_row_set.empty());
@@ -127,6 +135,13 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
UpdateLeafValues(&quantiles, nidx, p_tree);
}
#if !defined(XGBOOST_USE_CUDA)
void UpdateTreeLeafDevice(Context const*, common::Span<bst_node_t const>, std::int32_t,
MetaInfo const&, HostDeviceVector<float> const&, float, RegTree*) {
common::AssertGPUSupport();
}
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace detail
} // namespace obj
} // namespace xgboost