[backport] Fix device dispatch for linear updater. (#9507) (#9532)

This commit is contained in:
Jiaming Yuan
2023-08-29 15:10:43 +08:00
committed by GitHub
parent 4301558a57
commit a0d3573c74
6 changed files with 80 additions and 35 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2014-2022 by XGBoost Contributors
/**
* Copyright 2014-2023, XGBoost Contributors
* \file gblinear.cc
* \brief Implementation of Linear booster, with L1/L2 regularization: Elastic Net
* the update rule is parallel coordinate descent (shotgun)
@@ -26,9 +26,9 @@
#include "../common/timer.h"
#include "../common/common.h"
#include "../common/threading_utils.h"
#include "../common/error_msg.h"
namespace xgboost {
namespace gbm {
namespace xgboost::gbm {
DMLC_REGISTRY_FILE_TAG(gblinear);
@@ -83,7 +83,16 @@ class GBLinear : public GradientBooster {
}
param_.UpdateAllowUnknown(cfg);
param_.CheckGPUSupport();
updater_.reset(LinearUpdater::Create(param_.updater, ctx_));
if (param_.updater == "gpu_coord_descent") {
LOG(WARNING) << error::DeprecatedFunc("gpu_coord_descent", "2.0.0",
R"(device="cuda", updater="coord_descent")");
}
if (param_.updater == "coord_descent" && ctx_->IsCUDA()) {
updater_.reset(LinearUpdater::Create("gpu_coord_descent", ctx_));
} else {
updater_.reset(LinearUpdater::Create(param_.updater, ctx_));
}
updater_->Configure(cfg);
monitor_.Init("GBLinear");
}
@@ -354,5 +363,4 @@ XGBOOST_REGISTER_GBM(GBLinear, "gblinear")
.set_body([](LearnerModelParam const* booster_config, Context const* ctx) {
return new GBLinear(booster_config, ctx);
});
} // namespace gbm
} // namespace xgboost
} // namespace xgboost::gbm