Initial support for quantile loss. (#8750)

- Add support for Python.
- Add objective.
This commit is contained in:
Jiaming Yuan
2023-02-16 02:30:18 +08:00
committed by GitHub
parent 282b1729da
commit cce4af4acf
26 changed files with 701 additions and 70 deletions

View File

@@ -3,17 +3,25 @@
*/
#include "adaptive.h"
#include <limits>
#include <vector>
#include <algorithm> // std::transform,std::find_if,std::copy,std::unique
#include <cmath> // std::isnan
#include <cstddef> // std::size_t
#include <iterator> // std::distance
#include <vector> // std::vector
#include "../common/algorithm.h" // ArgSort
#include "../common/common.h" // AssertGPUSupport
#include "../common/numeric.h" // RunLengthEncode
#include "../common/stats.h" // Quantile,WeightedQuantile
#include "../common/threading_utils.h" // ParallelFor
#include "../common/transform_iterator.h" // MakeIndexTransformIter
#include "xgboost/base.h" // bst_node_t
#include "xgboost/context.h" // Context
#include "xgboost/linalg.h"
#include "xgboost/tree_model.h"
#include "xgboost/data.h" // MetaInfo
#include "xgboost/host_device_vector.h" // HostDeviceVector
#include "xgboost/linalg.h" // MakeTensorView
#include "xgboost/span.h" // Span
#include "xgboost/tree_model.h" // RegTree
namespace xgboost {
namespace obj {
@@ -100,8 +108,8 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
CHECK_LT(k + 1, h_node_ptr.size());
size_t n = h_node_ptr[k + 1] - h_node_ptr[k];
auto h_row_set = common::Span<size_t const>{ridx}.subspan(h_node_ptr[k], n);
CHECK_LE(group_idx, info.labels.Shape(1));
auto h_labels = info.labels.HostView().Slice(linalg::All(), group_idx);
auto h_labels = info.labels.HostView().Slice(linalg::All(), IdxY(info, group_idx));
auto h_weights = linalg::MakeVec(&info.weights_);
auto iter = common::MakeIndexTransformIter([&](size_t i) -> float {
@@ -115,9 +123,9 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
float q{0};
if (info.weights_.Empty()) {
q = common::Quantile(alpha, iter, iter + h_row_set.size());
q = common::Quantile(ctx, alpha, iter, iter + h_row_set.size());
} else {
q = common::WeightedQuantile(alpha, iter, iter + h_row_set.size(), w_it);
q = common::WeightedQuantile(ctx, alpha, iter, iter + h_row_set.size(), w_it);
}
if (std::isnan(q)) {
CHECK(h_row_set.empty());
@@ -127,6 +135,13 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
UpdateLeafValues(&quantiles, nidx, p_tree);
}
#if !defined(XGBOOST_USE_CUDA)
void UpdateTreeLeafDevice(Context const*, common::Span<bst_node_t const>, std::int32_t,
MetaInfo const&, HostDeviceVector<float> const&, float, RegTree*) {
common::AssertGPUSupport();
}
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace detail
} // namespace obj
} // namespace xgboost

View File

@@ -20,20 +20,19 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
HostDeviceVector<bst_node_t>* p_nidx, RegTree const& tree) {
// copy position to buffer
dh::safe_cuda(cudaSetDevice(ctx->gpu_id));
auto cuctx = ctx->CUDACtx();
size_t n_samples = position.size();
dh::XGBDeviceAllocator<char> alloc;
dh::device_vector<bst_node_t> sorted_position(position.size());
dh::safe_cuda(cudaMemcpyAsync(sorted_position.data().get(), position.data(),
position.size_bytes(), cudaMemcpyDeviceToDevice));
position.size_bytes(), cudaMemcpyDeviceToDevice, cuctx->Stream()));
p_ridx->resize(position.size());
dh::Iota(dh::ToSpan(*p_ridx));
// sort row index according to node index
thrust::stable_sort_by_key(thrust::cuda::par(alloc), sorted_position.begin(),
thrust::stable_sort_by_key(cuctx->TP(), sorted_position.begin(),
sorted_position.begin() + n_samples, p_ridx->begin());
dh::XGBCachingDeviceAllocator<char> caching;
size_t beg_pos =
thrust::find_if(thrust::cuda::par(caching), sorted_position.cbegin(), sorted_position.cend(),
thrust::find_if(cuctx->CTP(), sorted_position.cbegin(), sorted_position.cend(),
[] XGBOOST_DEVICE(bst_node_t nidx) { return nidx >= 0; }) -
sorted_position.cbegin();
if (beg_pos == sorted_position.size()) {
@@ -72,7 +71,7 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
size_t* h_num_runs = reinterpret_cast<size_t*>(pinned.subspan(0, sizeof(size_t)).data());
dh::CUDAEvent e;
e.Record(dh::DefaultStream());
e.Record(cuctx->Stream());
copy_stream.View().Wait(e);
// flag for whether there's ignored position
bst_node_t* h_first_unique =
@@ -108,7 +107,7 @@ void EncodeTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
d_node_ptr[0] = beg_pos;
}
});
thrust::inclusive_scan(thrust::cuda::par(caching), dh::tbegin(d_node_ptr), dh::tend(d_node_ptr),
thrust::inclusive_scan(cuctx->CTP(), dh::tbegin(d_node_ptr), dh::tend(d_node_ptr),
dh::tbegin(d_node_ptr));
copy_stream.View().Sync();
CHECK_GT(*h_num_runs, 0);
@@ -162,7 +161,7 @@ void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> pos
{info.num_row_, predt.Size() / info.num_row_}, ctx->gpu_id);
CHECK_LT(group_idx, d_predt.Shape(1));
auto t_predt = d_predt.Slice(linalg::All(), group_idx);
auto d_labels = info.labels.View(ctx->gpu_id).Slice(linalg::All(), group_idx);
auto d_labels = info.labels.View(ctx->gpu_id).Slice(linalg::All(), IdxY(info, group_idx));
auto d_row_index = dh::ToSpan(ridx);
auto seg_beg = nptr.DevicePointer();

View File

@@ -6,13 +6,15 @@
#include <algorithm>
#include <cstdint> // std::int32_t
#include <limits>
#include <vector>
#include <vector> // std::vector
#include "../collective/communicator-inl.h"
#include "../common/common.h"
#include "xgboost/context.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/tree_model.h"
#include "xgboost/base.h" // bst_node_t
#include "xgboost/context.h" // Context
#include "xgboost/data.h" // MetaInfo
#include "xgboost/host_device_vector.h" // HostDeviceVector
#include "xgboost/tree_model.h" // RegTree
namespace xgboost {
namespace obj {
@@ -73,6 +75,15 @@ inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_no
}
}
inline std::size_t IdxY(MetaInfo const& info, bst_group_t group_idx) {
std::size_t y_idx{0};
if (info.labels.Shape(1) > 1) {
y_idx = group_idx;
}
CHECK_LE(y_idx, info.labels.Shape(1));
return y_idx;
}
void UpdateTreeLeafDevice(Context const* ctx, common::Span<bst_node_t const> position,
std::int32_t group_idx, MetaInfo const& info,
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree);
@@ -81,5 +92,18 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
std::int32_t group_idx, MetaInfo const& info,
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree);
} // namespace detail
inline void UpdateTreeLeaf(Context const* ctx, HostDeviceVector<bst_node_t> const& position,
std::int32_t group_idx, MetaInfo const& info,
HostDeviceVector<float> const& predt, float alpha, RegTree* p_tree) {
if (ctx->IsCPU()) {
detail::UpdateTreeLeafHost(ctx, position.ConstHostVector(), group_idx, info, predt, alpha,
p_tree);
} else {
position.SetDevice(ctx->gpu_id);
detail::UpdateTreeLeafDevice(ctx, position.ConstDeviceSpan(), group_idx, info, predt, alpha,
p_tree);
}
}
} // namespace obj
} // namespace xgboost

View File

@@ -44,11 +44,13 @@ namespace obj {
// List of files that will be force linked in static links.
#ifdef XGBOOST_USE_CUDA
DMLC_REGISTRY_LINK_TAG(regression_obj_gpu);
DMLC_REGISTRY_LINK_TAG(quantile_obj_gpu);
DMLC_REGISTRY_LINK_TAG(hinge_obj_gpu);
DMLC_REGISTRY_LINK_TAG(multiclass_obj_gpu);
DMLC_REGISTRY_LINK_TAG(rank_obj_gpu);
#else
DMLC_REGISTRY_LINK_TAG(regression_obj);
DMLC_REGISTRY_LINK_TAG(quantile_obj);
DMLC_REGISTRY_LINK_TAG(hinge_obj);
DMLC_REGISTRY_LINK_TAG(multiclass_obj);
DMLC_REGISTRY_LINK_TAG(rank_obj);

View File

@@ -0,0 +1,18 @@
/**
* Copyright 2023 by XGBoost Contributors
*/
// Dummy file to enable the CUDA conditional compile trick.
#include <dmlc/registry.h>
namespace xgboost {
namespace obj {
DMLC_REGISTRY_FILE_TAG(quantile_obj);
} // namespace obj
} // namespace xgboost
#ifndef XGBOOST_USE_CUDA
#include "quantile_obj.cu"
#endif // !defined(XBGOOST_USE_CUDA)

View File

@@ -0,0 +1,226 @@
/**
* Copyright 2023 by XGBoost contributors
*/
#include <cstddef> // std::size_t
#include <cstdint> // std::int32_t
#include <vector> // std::vector
#include "../common/linalg_op.h" // ElementWiseKernel,cbegin,cend
#include "../common/quantile_loss_utils.h" // QuantileLossParam
#include "../common/stats.h" // Quantile,WeightedQuantile
#include "adaptive.h" // UpdateTreeLeaf
#include "dmlc/parameter.h" // DMLC_DECLARE_PARAMETER
#include "init_estimation.h" // CheckInitInputs
#include "xgboost/base.h" // GradientPair,XGBOOST_DEVICE,bst_target_t
#include "xgboost/data.h" // MetaInfo
#include "xgboost/host_device_vector.h" // HostDeviceVector
#include "xgboost/json.h" // Json,String,ToJson,FromJson
#include "xgboost/linalg.h" // Tensor,MakeTensorView,MakeVec
#include "xgboost/objective.h" // ObjFunction
#include "xgboost/parameter.h" // XGBoostParameter
#if defined(XGBOOST_USE_CUDA)
#include "../common/linalg_op.cuh" // ElementWiseKernel
#include "../common/stats.cuh" // SegmentedQuantile
#endif // defined(XGBOOST_USE_CUDA)
namespace xgboost {
namespace obj {
class QuantileRegression : public ObjFunction {
common::QuantileLossParam param_;
HostDeviceVector<float> alpha_;
bst_target_t Targets(MetaInfo const& info) const override {
auto const& alpha = param_.quantile_alpha.Get();
CHECK_EQ(alpha.size(), alpha_.Size()) << "The objective is not yet configured.";
CHECK_EQ(info.labels.Shape(1), 1) << "Multi-target is not yet supported by the quantile loss.";
CHECK(!alpha.empty());
// We have some placeholders for multi-target in the quantile loss. But it's not
// supported as the gbtree doesn't know how to slice the gradient and there's no 3-dim
// model shape in general.
auto n_y = std::max(static_cast<std::size_t>(1), info.labels.Shape(1));
return alpha_.Size() * n_y;
}
public:
void GetGradient(HostDeviceVector<float> const& preds, const MetaInfo& info, std::int32_t iter,
HostDeviceVector<GradientPair>* out_gpair) override {
if (iter == 0) {
CheckInitInputs(info);
}
CHECK_EQ(param_.quantile_alpha.Get().size(), alpha_.Size());
using SizeT = decltype(info.num_row_);
SizeT n_targets = this->Targets(info);
SizeT n_alphas = alpha_.Size();
CHECK_NE(n_alphas, 0);
CHECK_GE(n_targets, n_alphas);
CHECK_EQ(preds.Size(), info.num_row_ * n_targets);
auto labels = info.labels.View(ctx_->gpu_id);
out_gpair->SetDevice(ctx_->gpu_id);
out_gpair->Resize(n_targets * info.num_row_);
auto gpair =
linalg::MakeTensorView(ctx_->IsCPU() ? out_gpair->HostSpan() : out_gpair->DeviceSpan(),
{info.num_row_, n_alphas, n_targets / n_alphas}, ctx_->gpu_id);
info.weights_.SetDevice(ctx_->gpu_id);
common::OptionalWeights weight{ctx_->IsCPU() ? info.weights_.ConstHostSpan()
: info.weights_.ConstDeviceSpan()};
preds.SetDevice(ctx_->gpu_id);
auto predt = linalg::MakeVec(&preds);
auto n_samples = info.num_row_;
alpha_.SetDevice(ctx_->gpu_id);
auto alpha = ctx_->IsCPU() ? alpha_.ConstHostSpan() : alpha_.ConstDeviceSpan();
linalg::ElementWiseKernel(
ctx_, gpair, [=] XGBOOST_DEVICE(std::size_t i, GradientPair const&) mutable {
auto idx = linalg::UnravelIndex(
i, {n_samples, static_cast<SizeT>(alpha.size()), n_targets / alpha.size()});
// std::tie is not available for cuda kernel.
std::size_t sample_id = std::get<0>(idx);
std::size_t quantile_id = std::get<1>(idx);
std::size_t target_id = std::get<2>(idx);
auto d = predt(i) - labels(sample_id, target_id);
auto h = weight[sample_id];
if (d >= 0) {
auto g = (1.0f - alpha[quantile_id]) * weight[sample_id];
gpair(sample_id, quantile_id, target_id) = GradientPair{g, h};
} else {
auto g = (-alpha[quantile_id] * weight[sample_id]);
gpair(sample_id, quantile_id, target_id) = GradientPair{g, h};
}
});
}
void InitEstimation(MetaInfo const& info, linalg::Vector<float>* base_score) const override {
CHECK(!alpha_.Empty());
auto n_targets = this->Targets(info);
base_score->SetDevice(ctx_->gpu_id);
base_score->Reshape(n_targets);
double sw{0};
if (ctx_->IsCPU()) {
auto quantiles = base_score->HostView();
auto h_weights = info.weights_.ConstHostVector();
if (info.weights_.Empty()) {
sw = info.num_row_;
} else {
sw = std::accumulate(std::cbegin(h_weights), std::cend(h_weights), 0.0);
}
for (bst_target_t t{0}; t < n_targets; ++t) {
auto alpha = param_.quantile_alpha[t];
auto h_labels = info.labels.HostView();
if (h_weights.empty()) {
quantiles(t) =
common::Quantile(ctx_, alpha, linalg::cbegin(h_labels), linalg::cend(h_labels));
} else {
CHECK_EQ(h_weights.size(), h_labels.Size());
quantiles(t) = common::WeightedQuantile(ctx_, alpha, linalg::cbegin(h_labels),
linalg::cend(h_labels), std::cbegin(h_weights));
}
}
} else {
#if defined(XGBOOST_USE_CUDA)
alpha_.SetDevice(ctx_->gpu_id);
auto d_alpha = alpha_.ConstDeviceSpan();
auto d_labels = info.labels.View(ctx_->gpu_id);
auto seg_it = dh::MakeTransformIterator<std::size_t>(
thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) { return i * d_labels.Shape(0); });
CHECK_EQ(d_labels.Shape(1), 1);
auto val_it = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) {
auto sample_idx = i % d_labels.Shape(0);
return d_labels(sample_idx, 0);
});
auto n = d_labels.Size() * d_alpha.size();
CHECK_EQ(base_score->Size(), d_alpha.size());
if (info.weights_.Empty()) {
common::SegmentedQuantile(ctx_, d_alpha.data(), seg_it, seg_it + d_alpha.size() + 1, val_it,
val_it + n, base_score->Data());
sw = info.num_row_;
} else {
info.weights_.SetDevice(ctx_->gpu_id);
auto d_weights = info.weights_.ConstDeviceSpan();
auto weight_it = dh::MakeTransformIterator<float>(thrust::make_counting_iterator(0ul),
[=] XGBOOST_DEVICE(std::size_t i) {
auto sample_idx = i % d_labels.Shape(0);
return d_weights[sample_idx];
});
common::SegmentedWeightedQuantile(ctx_, d_alpha.data(), seg_it, seg_it + d_alpha.size() + 1,
val_it, val_it + n, weight_it, weight_it + n,
base_score->Data());
sw = dh::Reduce(ctx_->CUDACtx()->CTP(), dh::tcbegin(d_weights), dh::tcend(d_weights), 0.0,
thrust::plus<double>{});
}
#else
common::AssertGPUSupport();
#endif // defined(XGBOOST_USE_CUDA)
}
// For multiple quantiles, we should extend the base score to a vector instead of
// computing the average. For now, this is a workaround.
linalg::Vector<float> temp;
common::Mean(ctx_, *base_score, &temp);
double meanq = temp(0) * sw;
collective::Allreduce<collective::Operation::kSum>(&meanq, 1);
collective::Allreduce<collective::Operation::kSum>(&sw, 1);
meanq /= (sw + kRtEps);
base_score->Reshape(1);
base_score->Data()->Fill(meanq);
}
void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
HostDeviceVector<float> const& prediction, std::int32_t group_idx,
RegTree* p_tree) const override {
auto alpha = param_.quantile_alpha[group_idx];
::xgboost::obj::UpdateTreeLeaf(ctx_, position, group_idx, info, prediction, alpha, p_tree);
}
void Configure(Args const& args) override {
param_.UpdateAllowUnknown(args);
param_.Validate();
this->alpha_.HostVector() = param_.quantile_alpha.Get();
}
ObjInfo Task() const override { return {ObjInfo::kRegression, true, true}; }
static char const* Name() { return "reg:quantileerror"; }
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String(Name());
out["quantile_loss_param"] = ToJson(param_);
}
void LoadConfig(Json const& in) override {
CHECK_EQ(get<String const>(in["name"]), Name());
FromJson(in["quantile_loss_param"], &param_);
alpha_.HostVector() = param_.quantile_alpha.Get();
}
const char* DefaultEvalMetric() const override { return "quantile"; }
Json DefaultMetricConfig() const override {
CHECK(param_.GetInitialised());
Json config{Object{}};
config["name"] = String{this->DefaultEvalMetric()};
config["quantile_loss_param"] = ToJson(param_);
return config;
}
};
XGBOOST_REGISTER_OBJECTIVE(QuantileRegression, QuantileRegression::Name())
.describe("Regression with quantile loss.")
.set_body([]() { return new QuantileRegression(); });
#if defined(XGBOOST_USE_CUDA)
DMLC_REGISTRY_FILE_TAG(quantile_obj_gpu);
#endif // defined(XGBOOST_USE_CUDA)
} // namespace obj
} // namespace xgboost

View File

@@ -1,15 +1,16 @@
/*!
* Copyright 2017-2022 XGBoost contributors
/**
* Copyright 2017-2023 by XGBoost contributors
*/
#ifndef XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_
#define XGBOOST_OBJECTIVE_REGRESSION_LOSS_H_
#include <dmlc/omp.h>
#include <xgboost/logging.h>
#include <cmath>
#include "../common/math.h"
#include "xgboost/data.h" // MetaInfo
#include "xgboost/logging.h"
#include "xgboost/task.h" // ObjInfo
namespace xgboost {
@@ -105,7 +106,6 @@ struct LogisticRaw : public LogisticRegression {
static ObjInfo Info() { return ObjInfo::kRegression; }
};
} // namespace obj
} // namespace xgboost

View File

@@ -744,18 +744,7 @@ class MeanAbsoluteError : public ObjFunction {
void UpdateTreeLeaf(HostDeviceVector<bst_node_t> const& position, MetaInfo const& info,
HostDeviceVector<float> const& prediction, std::int32_t group_idx,
RegTree* p_tree) const override {
if (ctx_->IsCPU()) {
auto const& h_position = position.ConstHostVector();
detail::UpdateTreeLeafHost(ctx_, h_position, group_idx, info, prediction, 0.5, p_tree);
} else {
#if defined(XGBOOST_USE_CUDA)
position.SetDevice(ctx_->gpu_id);
auto d_position = position.ConstDeviceSpan();
detail::UpdateTreeLeafDevice(ctx_, d_position, group_idx, info, prediction, 0.5, p_tree);
#else
common::AssertGPUSupport();
#endif // defined(XGBOOST_USE_CUDA)
}
::xgboost::obj::UpdateTreeLeaf(ctx_, position, group_idx, info, prediction, 0.5, p_tree);
}
const char* DefaultEvalMetric() const override { return "mae"; }