[backport] Fix device dispatch for linear updater. (#9507) (#9532)

This commit is contained in:
Jiaming Yuan
2023-08-29 15:10:43 +08:00
committed by GitHub
parent 4301558a57
commit a0d3573c74
6 changed files with 80 additions and 35 deletions

View File

@@ -9,8 +9,7 @@
#include "coordinate_common.h"
#include "xgboost/json.h"
namespace xgboost {
namespace linear {
namespace xgboost::linear {
DMLC_REGISTER_PARAMETER(CoordinateParam);
DMLC_REGISTRY_FILE_TAG(updater_coordinate);
@@ -39,8 +38,9 @@ class CoordinateUpdater : public LinearUpdater {
FromJson(config.at("linear_train_param"), &tparam_);
FromJson(config.at("coordinate_param"), &cparam_);
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
void SaveConfig(Json *p_out) const override {
LOG(DEBUG) << "Save config for CPU updater.";
auto &out = *p_out;
out["linear_train_param"] = ToJson(tparam_);
out["coordinate_param"] = ToJson(cparam_);
}
@@ -99,5 +99,4 @@ class CoordinateUpdater : public LinearUpdater {
XGBOOST_REGISTER_LINEAR_UPDATER(CoordinateUpdater, "coord_descent")
.describe("Update linear model according to coordinate descent algorithm.")
.set_body([]() { return new CoordinateUpdater(); });
} // namespace linear
} // namespace xgboost
} // namespace xgboost::linear

View File

@@ -15,8 +15,7 @@
#include "../common/timer.h"
#include "./param.h"
namespace xgboost {
namespace linear {
namespace xgboost::linear {
DMLC_REGISTRY_FILE_TAG(updater_gpu_coordinate);
@@ -29,7 +28,7 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_coordinate);
class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
public:
// set training parameter
void Configure(Args const& args) override {
void Configure(Args const &args) override {
tparam_.UpdateAllowUnknown(args);
coord_param_.UpdateAllowUnknown(args);
selector_.reset(FeatureSelector::Create(tparam_.feature_selector));
@@ -41,8 +40,9 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
FromJson(config.at("linear_train_param"), &tparam_);
FromJson(config.at("coordinate_param"), &coord_param_);
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
void SaveConfig(Json *p_out) const override {
LOG(DEBUG) << "Save config for GPU updater.";
auto &out = *p_out;
out["linear_train_param"] = ToJson(tparam_);
out["coordinate_param"] = ToJson(coord_param_);
}
@@ -101,10 +101,9 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
monitor_.Stop("LazyInitDevice");
monitor_.Start("UpdateGpair");
auto &in_gpair_host = in_gpair->ConstHostVector();
// Update gpair
if (ctx_->gpu_id >= 0) {
this->UpdateGpair(in_gpair_host);
this->UpdateGpair(in_gpair->ConstHostVector());
}
monitor_.Stop("UpdateGpair");
@@ -249,5 +248,4 @@ XGBOOST_REGISTER_LINEAR_UPDATER(GPUCoordinateUpdater, "gpu_coord_descent")
"Update linear model according to coordinate descent algorithm. GPU "
"accelerated.")
.set_body([]() { return new GPUCoordinateUpdater(); });
} // namespace linear
} // namespace xgboost
} // namespace xgboost::linear