Better message when no GPU is found. (#5594)

Jiaming Yuan 2020-04-26 10:00:57 +08:00 committed by GitHub
parent 8dfe7b3686
commit 7d93932423
4 changed files with 37 additions and 0 deletions


@@ -407,6 +407,7 @@ GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
   if (tparam_.predictor != PredictorType::kAuto) {
     if (tparam_.predictor == PredictorType::kGPUPredictor) {
 #if defined(XGBOOST_USE_CUDA)
+      CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
       CHECK(gpu_predictor_);
       return gpu_predictor_;
 #else
@@ -429,6 +430,7 @@ GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
   // Use GPU Predictor if data is already on device and gpu_id is set.
   if (on_device && generic_param_->gpu_id >= 0) {
 #if defined(XGBOOST_USE_CUDA)
+    CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
     CHECK(gpu_predictor_);
     return gpu_predictor_;
 #else
@@ -454,6 +456,7 @@ GBTree::GetPredictor(HostDeviceVector<float> const *out_pred,
   if (tparam_.tree_method == TreeMethod::kGPUHist) {
 #if defined(XGBOOST_USE_CUDA)
+    CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
     CHECK(gpu_predictor_);
     return gpu_predictor_;
 #else
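
Taken together, the three CHECK_GE guards above cover every branch of GBTree::GetPredictor that hands back gpu_predictor_ (predictor forced to gpu_predictor, data already resident on a device, or tree_method set to gpu_hist). The sketch below is not part of the commit; it only illustrates how a user would hit the new diagnostic from Python, assuming a CUDA-enabled XGBoost build and that all devices are hidden through CUDA_VISIBLE_DEVICES:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # hide every GPU before XGBoost touches CUDA

import numpy as np
import xgboost as xgb

X = np.random.randn(10, 10)
y = np.random.randn(10)

try:
    # Forcing a GPU code path with no visible device should now fail fast with a
    # readable diagnostic instead of a low-level thrust exception.
    xgb.train({'tree_method': 'gpu_hist'}, xgb.DMatrix(X, y))
except xgb.core.XGBoostError as err:
    print(err)  # the "No visible GPU is found" message appears in the error/log output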


@@ -187,8 +187,13 @@ void GenericParameter::ConfigureGpuId(bool require_gpu) {
   // raw model), number of available GPUs could be different. Wrap around it.
   int32_t n_gpus = common::AllVisibleGPUs();
   if (n_gpus == 0) {
+    if (gpu_id != kCpuId) {
+      LOG(WARNING) << "No visible GPU is found, setting `gpu_id` to -1";
+    }
     this->UpdateAllowUnknown(Args{{"gpu_id", std::to_string(kCpuId)}});
   } else if (gpu_id != kCpuId && gpu_id >= n_gpus) {
+    LOG(WARNING) << "Only " << n_gpus
+                 << " GPUs are visible, setting `gpu_id` to " << gpu_id % n_gpus;
     this->UpdateAllowUnknown(Args{{"gpu_id", std::to_string(gpu_id % n_gpus)}});
   }
 #else
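
For illustration only (not part of the commit), the gpu_id normalisation that GenericParameter::ConfigureGpuId performs under a CUDA build can be restated in Python roughly as follows; kCpuId is the CPU sentinel (-1) in the C++ source, and print stands in for LOG(WARNING):

def configure_gpu_id_sketch(gpu_id, n_gpus, k_cpu_id=-1):
    """Clamp `gpu_id` to the devices that are actually visible; warn whenever
    the value has to change. A Python restatement of the C++ branch above."""
    if n_gpus == 0:
        if gpu_id != k_cpu_id:
            print("No visible GPU is found, setting `gpu_id` to -1")
        return k_cpu_id
    if gpu_id != k_cpu_id and gpu_id >= n_gpus:
        print("Only %d GPUs are visible, setting `gpu_id` to %d"
              % (n_gpus, gpu_id % n_gpus))
        return gpu_id % n_gpus
    return gpu_id

# configure_gpu_id_sketch(5, n_gpus=2)  -> warns, returns 1 (wrap-around)
# configure_gpu_id_sketch(0, n_gpus=0)  -> warns, returns -1 (fall back to CPU)
# configure_gpu_id_sketch(-1, n_gpus=0) -> silent, returns -1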


@@ -2,11 +2,17 @@
 `test_gpu_with_dask.py`'''
 import unittest
 import os
+import numpy as np
 import xgboost as xgb
 import json
+import pytest
+import sys
 
 from test_gpu_pickling import build_dataset, model_path, load_pickle
 
+sys.path.append("tests/python")
+import test_basic as tb
+
 
 class TestLoadPickle(unittest.TestCase):
     def test_load_pkl(self):
def test_load_pkl(self): def test_load_pkl(self):
@@ -49,3 +55,15 @@ class TestLoadPickle(unittest.TestCase):
         test_x = xgb.DMatrix(x)
         res = bst.predict(test_x)
         assert len(res) == 10
+
+    def test_training_on_cpu_only_env(self):
+        assert os.environ['CUDA_VISIBLE_DEVICES'] == '-1'
+        rng = np.random.RandomState(1994)
+        X = rng.randn(10, 10)
+        y = rng.randn(10)
+        with tb.captured_output() as (out, err):
+            # Test no thrust exception is thrown
+            with pytest.raises(xgb.core.XGBoostError):
+                xgb.train({'tree_method': 'gpu_hist'}, xgb.DMatrix(X, y))
+
+            assert out.getvalue().find('No visible GPU is found') != -1
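
The helper tb.captured_output used above comes from tests/python/test_basic.py in the repository (hence the sys.path.append in the import block). For readers without the test tree at hand, here is a self-contained sketch of such a stdout/stderr capture helper; this is an assumption about its shape, not a copy of the real implementation:

import sys
from contextlib import contextmanager
from io import StringIO

@contextmanager
def captured_output():
    """Redirect stdout and stderr to in-memory buffers for the duration of the
    `with` block, yielding both so a test can assert on what was printed."""
    new_out, new_err = StringIO(), StringIO()
    old_out, old_err = sys.stdout, sys.stderr
    try:
        sys.stdout, sys.stderr = new_out, new_err
        yield new_out, new_err
    finally:
        sys.stdout, sys.stderr = old_out, old_err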


@@ -157,3 +157,14 @@ class TestPickling(unittest.TestCase):
         bst.set_param({'predictor': 'cpu_predictor'})
         cpu_pred = model.predict(x, output_margin=True)
         np.testing.assert_allclose(cpu_pred, gpu_pred, rtol=1e-5)
+
+    def test_training_on_cpu_only_env(self):
+        cuda_environment = {'CUDA_VISIBLE_DEVICES': '-1'}
+        env = os.environ.copy()
+        env.update(cuda_environment)
+        args = self.args_template.copy()
+        args.append(
+            "./tests/python-gpu/"
+            "load_pickle.py::TestLoadPickle::test_training_on_cpu_only_env")
+        status = subprocess.call(args, env=env)
+        assert status == 0