Use matrix for gradient. (#9508)
- Use the `linalg::Matrix` for storing gradients. - New API for the custom objective. - Custom objective for multi-class/multi-target is now required to return the correct shape. - Custom objective for Python can accept arrays with any strides. (row-major, column-major)
This commit is contained in:
@@ -27,8 +27,8 @@ class HingeObj : public ObjFunction {
|
||||
void Configure(Args const&) override {}
|
||||
ObjInfo Task() const override { return ObjInfo::kRegression; }
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float> &preds, const MetaInfo &info, int /*iter*/,
|
||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||
void GetGradient(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
|
||||
std::int32_t /*iter*/, linalg::Matrix<GradientPair> *out_gpair) override {
|
||||
CHECK_NE(info.labels.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.Size(), info.labels.Size())
|
||||
<< "labels are not correctly provided"
|
||||
@@ -41,7 +41,8 @@ class HingeObj : public ObjFunction {
|
||||
CHECK_EQ(info.weights_.Size(), ndata)
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
out_gpair->Resize(ndata);
|
||||
CHECK_EQ(info.labels.Shape(1), 1) << "Multi-target for `binary:hinge` is not yet supported.";
|
||||
out_gpair->Reshape(ndata, 1);
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
common::Span<GradientPair> _out_gpair,
|
||||
@@ -63,7 +64,7 @@ class HingeObj : public ObjFunction {
|
||||
},
|
||||
common::Range{0, static_cast<int64_t>(ndata)}, this->ctx_->Threads(),
|
||||
ctx_->gpu_id).Eval(
|
||||
out_gpair, &preds, info.labels.Data(), &info.weights_);
|
||||
out_gpair->Data(), &preds, info.labels.Data(), &info.weights_);
|
||||
}
|
||||
|
||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
|
||||
|
||||
Reference in New Issue
Block a user