Use matrix for gradient. (#9508)

- Use the `linalg::Matrix` for storing gradients.
- New API for the custom objective.
- Custom objective for multi-class/multi-target is now required to return the correct shape.
- Custom objective for Python can accept arrays with any strides. (row-major, column-major)
This commit is contained in:
Jiaming Yuan
2023-08-24 05:29:52 +08:00
committed by GitHub
parent 6103dca0bb
commit 972730cde0
77 changed files with 1052 additions and 651 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2019-2022 by Contributors
/**
* Copyright 2019-2023, XGBoost Contributors
* \file aft_obj.cu
* \brief Definition of AFT loss for survival analysis.
* \author Avinash Barnwal, Hyunsu Cho and Toby Hocking
@@ -41,11 +41,9 @@ class AFTObj : public ObjFunction {
ObjInfo Task() const override { return ObjInfo::kSurvival; }
template <typename Distribution>
void GetGradientImpl(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
HostDeviceVector<GradientPair> *out_gpair,
size_t ndata, int device, bool is_null_weight,
float aft_loss_distribution_scale) {
void GetGradientImpl(const HostDeviceVector<bst_float>& preds, const MetaInfo& info,
linalg::Matrix<GradientPair>* out_gpair, size_t ndata, int device,
bool is_null_weight, float aft_loss_distribution_scale) {
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx,
common::Span<GradientPair> _out_gpair,
@@ -66,16 +64,17 @@ class AFTObj : public ObjFunction {
_out_gpair[_idx] = GradientPair(grad * w, hess * w);
},
common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(), device).Eval(
out_gpair, &preds, &info.labels_lower_bound_, &info.labels_upper_bound_,
out_gpair->Data(), &preds, &info.labels_lower_bound_, &info.labels_upper_bound_,
&info.weights_);
}
void GetGradient(const HostDeviceVector<bst_float>& preds, const MetaInfo& info, int /*iter*/,
HostDeviceVector<GradientPair>* out_gpair) override {
linalg::Matrix<GradientPair>* out_gpair) override {
const size_t ndata = preds.Size();
CHECK_EQ(info.labels_lower_bound_.Size(), ndata);
CHECK_EQ(info.labels_upper_bound_.Size(), ndata);
out_gpair->Resize(ndata);
out_gpair->SetDevice(ctx_->Device());
out_gpair->Reshape(ndata, 1);
const int device = ctx_->gpu_id;
const float aft_loss_distribution_scale = param_.aft_loss_distribution_scale;
const bool is_null_weight = info.weights_.Size() == 0;