Use matrix for gradient. (#9508)
- Use the `linalg::Matrix` for storing gradients. - New API for the custom objective. - Custom objective for multi-class/multi-target is now required to return the correct shape. - Custom objective for Python can accept arrays with any strides. (row-major, column-major)
This commit is contained in:
@@ -93,17 +93,18 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
||||
}
|
||||
}
|
||||
|
||||
void Update(HostDeviceVector<GradientPair> *in_gpair, DMatrix *p_fmat,
|
||||
gbm::GBLinearModel *model, double sum_instance_weight) override {
|
||||
void Update(linalg::Matrix<GradientPair> *in_gpair, DMatrix *p_fmat, gbm::GBLinearModel *model,
|
||||
double sum_instance_weight) override {
|
||||
tparam_.DenormalizePenalties(sum_instance_weight);
|
||||
monitor_.Start("LazyInitDevice");
|
||||
this->LazyInitDevice(p_fmat, *(model->learner_model_param));
|
||||
monitor_.Stop("LazyInitDevice");
|
||||
|
||||
monitor_.Start("UpdateGpair");
|
||||
|
||||
// Update gpair
|
||||
if (ctx_->gpu_id >= 0) {
|
||||
this->UpdateGpair(in_gpair->ConstHostVector());
|
||||
if (ctx_->IsCUDA()) {
|
||||
this->UpdateGpair(in_gpair->Data()->ConstHostVector());
|
||||
}
|
||||
monitor_.Stop("UpdateGpair");
|
||||
|
||||
@@ -111,15 +112,15 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
||||
this->UpdateBias(model);
|
||||
monitor_.Stop("UpdateBias");
|
||||
// prepare for updating the weights
|
||||
selector_->Setup(ctx_, *model, in_gpair->ConstHostVector(), p_fmat, tparam_.reg_alpha_denorm,
|
||||
tparam_.reg_lambda_denorm, coord_param_.top_k);
|
||||
selector_->Setup(ctx_, *model, in_gpair->Data()->ConstHostVector(), p_fmat,
|
||||
tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm, coord_param_.top_k);
|
||||
monitor_.Start("UpdateFeature");
|
||||
for (uint32_t group_idx = 0; group_idx < model->learner_model_param->num_output_group;
|
||||
++group_idx) {
|
||||
for (auto i = 0U; i < model->learner_model_param->num_feature; i++) {
|
||||
auto fidx =
|
||||
selector_->NextFeature(ctx_, i, *model, group_idx, in_gpair->ConstHostVector(), p_fmat,
|
||||
tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm);
|
||||
selector_->NextFeature(ctx_, i, *model, group_idx, in_gpair->Data()->ConstHostVector(),
|
||||
p_fmat, tparam_.reg_alpha_denorm, tparam_.reg_lambda_denorm);
|
||||
if (fidx < 0) break;
|
||||
this->UpdateFeature(fidx, group_idx, model);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user