Use matrix for gradient. (#9508)

- Use the `linalg::Matrix` for storing gradients.
- New API for the custom objective.
- Custom objective for multi-class/multi-target is now required to return the correct shape.
- Custom objective for Python can accept arrays with any strides. (row-major, column-major)
This commit is contained in:
Jiaming Yuan
2023-08-24 05:29:52 +08:00
committed by GitHub
parent 6103dca0bb
commit 972730cde0
77 changed files with 1052 additions and 651 deletions

View File

@@ -93,17 +93,18 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
}
}
void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat,
gbm::GBLinearModel *model, double sum_instance_weight) override {
void Update(linalg::Matrix<GradientPair> *in_gpair, DMatrix *p_fmat, gbm::GBLinearModel *model,
double sum_instance_weight) override {
tparam_.DenormalizePenalties(sum_instance_weight);
monitor_.Start("LazyInitDevice");
this->LazyInitDevice(p_fmat, *(model->learner_model_param));
monitor_.Stop("LazyInitDevice");
monitor_.Start("UpdateGpair");
// Update gpair
if (ctx_->gpu_id >= 0) {
this->UpdateGpair(in_gpair->ConstHostVector());
if (ctx_->IsCUDA()) {
this->UpdateGpair(in_gpair->Data()->ConstHostVector());
}
monitor_.Stop("UpdateGpair");
@@ -111,15 +112,15 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
this->UpdateBias(model);
monitor_.Stop("UpdateBias");
// prepare for updating the weights
selector_->Setup(ctx_, *model, in_gpair->ConstHostVector(), p_fmat, tparam_.reg_alpha_denorm,
tparam_.reg_lambda_denorm, coord_param_.top_k);
selector_->Setup(ctx_, *model, in_gpair->Data()->ConstHostVector(), p_fmat,
tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm, coord_param_.top_k);
monitor_.Start("UpdateFeature");
for (uint32_t group_idx = 0; group_idx < model->learner_model_param->num_output_group;
++group_idx) {
for (auto i = 0U; i < model->learner_model_param->num_feature; i++) {
auto fidx =
selector_->NextFeature(ctx_, i, *model, group_idx, in_gpair->ConstHostVector(), p_fmat,
tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm);
selector_->NextFeature(ctx_, i, *model, group_idx, in_gpair->Data()->ConstHostVector(),
p_fmat, tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm);
if (fidx < 0) break;
this->UpdateFeature(fidx, group_idx, model);
}