Avoid calling CUDA code on CPU for linear model. (#7154)
This commit is contained in:
parent
ba69244a94
commit
3a4f51f39f
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2014-2020 by Contributors
|
* Copyright 2014-2021 by Contributors
|
||||||
* \file gblinear.cc
|
* \file gblinear.cc
|
||||||
* \brief Implementation of Linear booster, with L1/L2 regularization: Elastic Net
|
* \brief Implementation of Linear booster, with L1/L2 regularization: Elastic Net
|
||||||
* the update rule is parallel coordinate descent (shotgun)
|
* the update rule is parallel coordinate descent (shotgun)
|
||||||
@ -37,6 +37,17 @@ struct GBLinearTrainParam : public XGBoostParameter<GBLinearTrainParam> {
|
|||||||
std::string updater;
|
std::string updater;
|
||||||
float tolerance;
|
float tolerance;
|
||||||
size_t max_row_perbatch;
|
size_t max_row_perbatch;
|
||||||
|
|
||||||
|
void CheckGPUSupport() {
|
||||||
|
auto n_gpus = common::AllVisibleGPUs();
|
||||||
|
if (n_gpus == 0 && this->updater == "gpu_coord_descent") {
|
||||||
|
common::AssertGPUSupport();
|
||||||
|
this->UpdateAllowUnknown(Args{{"updater", "coord_descent"}});
|
||||||
|
LOG(WARNING) << "Loading configuration on a CPU only machine. Changing "
|
||||||
|
"updater to `coord_descent`.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
DMLC_DECLARE_PARAMETER(GBLinearTrainParam) {
|
DMLC_DECLARE_PARAMETER(GBLinearTrainParam) {
|
||||||
DMLC_DECLARE_FIELD(updater)
|
DMLC_DECLARE_FIELD(updater)
|
||||||
.set_default("shotgun")
|
.set_default("shotgun")
|
||||||
@ -74,12 +85,10 @@ class GBLinear : public GradientBooster {
|
|||||||
model_.Configure(cfg);
|
model_.Configure(cfg);
|
||||||
}
|
}
|
||||||
param_.UpdateAllowUnknown(cfg);
|
param_.UpdateAllowUnknown(cfg);
|
||||||
|
param_.CheckGPUSupport();
|
||||||
updater_.reset(LinearUpdater::Create(param_.updater, generic_param_));
|
updater_.reset(LinearUpdater::Create(param_.updater, generic_param_));
|
||||||
updater_->Configure(cfg);
|
updater_->Configure(cfg);
|
||||||
monitor_.Init("GBLinear");
|
monitor_.Init("GBLinear");
|
||||||
if (param_.updater == "gpu_coord_descent") {
|
|
||||||
common::AssertGPUSupport();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t BoostedRounds() const override {
|
int32_t BoostedRounds() const override {
|
||||||
@ -110,6 +119,7 @@ class GBLinear : public GradientBooster {
|
|||||||
void LoadConfig(Json const& in) override {
|
void LoadConfig(Json const& in) override {
|
||||||
CHECK_EQ(get<String>(in["name"]), "gblinear");
|
CHECK_EQ(get<String>(in["name"]), "gblinear");
|
||||||
FromJson(in["gblinear_train_param"], ¶m_);
|
FromJson(in["gblinear_train_param"], ¶m_);
|
||||||
|
param_.CheckGPUSupport();
|
||||||
updater_.reset(LinearUpdater::Create(param_.updater, generic_param_));
|
updater_.reset(LinearUpdater::Create(param_.updater, generic_param_));
|
||||||
this->updater_->LoadConfig(in["updater"]);
|
this->updater_->LoadConfig(in["updater"]);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -28,12 +28,6 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_coordinate);
|
|||||||
|
|
||||||
class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
|
||||||
public:
|
public:
|
||||||
~GPUCoordinateUpdater() { // NOLINT
|
|
||||||
if (learner_param_->gpu_id >= 0) {
|
|
||||||
dh::safe_cuda(cudaSetDevice(learner_param_->gpu_id));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// set training parameter
|
// set training parameter
|
||||||
void Configure(Args const& args) override {
|
void Configure(Args const& args) override {
|
||||||
tparam_.UpdateAllowUnknown(args);
|
tparam_.UpdateAllowUnknown(args);
|
||||||
|
|||||||
@ -19,8 +19,15 @@ class TestLoadPickle:
|
|||||||
assert os.environ['CUDA_VISIBLE_DEVICES'] == '-1'
|
assert os.environ['CUDA_VISIBLE_DEVICES'] == '-1'
|
||||||
bst = load_pickle(model_path)
|
bst = load_pickle(model_path)
|
||||||
x, y = build_dataset()
|
x, y = build_dataset()
|
||||||
|
if isinstance(bst, xgb.Booster):
|
||||||
test_x = xgb.DMatrix(x)
|
test_x = xgb.DMatrix(x)
|
||||||
res = bst.predict(test_x)
|
res = bst.predict(test_x)
|
||||||
|
else:
|
||||||
|
res = bst.predict(x)
|
||||||
|
assert len(res) == 10
|
||||||
|
bst.set_params(n_jobs=1) # triggers a re-configuration
|
||||||
|
res = bst.predict(x)
|
||||||
|
|
||||||
assert len(res) == 10
|
assert len(res) == 10
|
||||||
|
|
||||||
def test_predictor_type_is_auto(self):
|
def test_predictor_type_is_auto(self):
|
||||||
|
|||||||
@ -41,13 +41,7 @@ class TestPickling:
|
|||||||
"-s",
|
"-s",
|
||||||
"--fulltrace"]
|
"--fulltrace"]
|
||||||
|
|
||||||
def test_pickling(self):
|
def run_pickling(self, bst) -> None:
|
||||||
x, y = build_dataset()
|
|
||||||
train_x = xgb.DMatrix(x, label=y)
|
|
||||||
param = {'tree_method': 'gpu_hist',
|
|
||||||
'verbosity': 1}
|
|
||||||
bst = xgb.train(param, train_x)
|
|
||||||
|
|
||||||
save_pickle(bst, model_path)
|
save_pickle(bst, model_path)
|
||||||
args = [
|
args = [
|
||||||
"pytest", "--verbose", "-s", "--fulltrace",
|
"pytest", "--verbose", "-s", "--fulltrace",
|
||||||
@ -71,6 +65,25 @@ class TestPickling:
|
|||||||
assert status == 0
|
assert status == 0
|
||||||
os.remove(model_path)
|
os.remove(model_path)
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_sklearn())
|
||||||
|
def test_pickling(self):
|
||||||
|
x, y = build_dataset()
|
||||||
|
train_x = xgb.DMatrix(x, label=y)
|
||||||
|
|
||||||
|
param = {'tree_method': 'gpu_hist', "gpu_id": 0}
|
||||||
|
bst = xgb.train(param, train_x)
|
||||||
|
self.run_pickling(bst)
|
||||||
|
|
||||||
|
bst = xgb.XGBRegressor(**param).fit(x, y)
|
||||||
|
self.run_pickling(bst)
|
||||||
|
|
||||||
|
param = {"booster": "gblinear", "updater": "gpu_coord_descent", "gpu_id": 0}
|
||||||
|
bst = xgb.train(param, train_x)
|
||||||
|
self.run_pickling(bst)
|
||||||
|
|
||||||
|
bst = xgb.XGBRegressor(**param).fit(x, y)
|
||||||
|
self.run_pickling(bst)
|
||||||
|
|
||||||
@pytest.mark.mgpu
|
@pytest.mark.mgpu
|
||||||
def test_wrap_gpu_id(self):
|
def test_wrap_gpu_id(self):
|
||||||
X, y = build_dataset()
|
X, y = build_dataset()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user