diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc
index 08bfd7ec2..61a3021cb 100644
--- a/src/gbm/gbtree.cc
+++ b/src/gbm/gbtree.cc
@@ -407,6 +407,7 @@ GBTree::GetPredictor(HostDeviceVector const *out_pred,
   if (tparam_.predictor != PredictorType::kAuto) {
     if (tparam_.predictor == PredictorType::kGPUPredictor) {
 #if defined(XGBOOST_USE_CUDA)
+      CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
       CHECK(gpu_predictor_);
       return gpu_predictor_;
 #else
@@ -429,6 +430,7 @@ GBTree::GetPredictor(HostDeviceVector const *out_pred,
   // Use GPU Predictor if data is already on device and gpu_id is set.
   if (on_device && generic_param_->gpu_id >= 0) {
 #if defined(XGBOOST_USE_CUDA)
+    CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
     CHECK(gpu_predictor_);
     return gpu_predictor_;
 #else
@@ -454,6 +456,7 @@ GBTree::GetPredictor(HostDeviceVector const *out_pred,
 
   if (tparam_.tree_method == TreeMethod::kGPUHist) {
 #if defined(XGBOOST_USE_CUDA)
+    CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
     CHECK(gpu_predictor_);
     return gpu_predictor_;
 #else
diff --git a/src/learner.cc b/src/learner.cc
index e4c925ebf..de9620c9a 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -187,8 +187,13 @@ void GenericParameter::ConfigureGpuId(bool require_gpu) {
   // raw model), number of available GPUs could be different. Wrap around it.
   int32_t n_gpus = common::AllVisibleGPUs();
   if (n_gpus == 0) {
+    if (gpu_id != kCpuId) {
+      LOG(WARNING) << "No visible GPU is found, setting `gpu_id` to -1";
+    }
     this->UpdateAllowUnknown(Args{{"gpu_id", std::to_string(kCpuId)}});
   } else if (gpu_id != kCpuId && gpu_id >= n_gpus) {
+    LOG(WARNING) << "Only " << n_gpus
+                 << " GPUs are visible, setting `gpu_id` to " << gpu_id % n_gpus;
     this->UpdateAllowUnknown(Args{{"gpu_id", std::to_string(gpu_id % n_gpus)}});
   }
 #else
diff --git a/tests/python-gpu/load_pickle.py b/tests/python-gpu/load_pickle.py
index 447df3034..2a75e612d 100644
--- a/tests/python-gpu/load_pickle.py
+++ b/tests/python-gpu/load_pickle.py
@@ -2,11 +2,17 @@
 `test_gpu_with_dask.py`'''
 import unittest
 import os
+import numpy as np
 import xgboost as xgb
 import json
+import pytest
+import sys
 
 from test_gpu_pickling import build_dataset, model_path, load_pickle
 
+sys.path.append("tests/python")
+import test_basic as tb
+
 
 class TestLoadPickle(unittest.TestCase):
     def test_load_pkl(self):
@@ -49,3 +55,15 @@ class TestLoadPickle(unittest.TestCase):
         test_x = xgb.DMatrix(x)
         res = bst.predict(test_x)
         assert len(res) == 10
+
+    def test_training_on_cpu_only_env(self):
+        assert os.environ['CUDA_VISIBLE_DEVICES'] == '-1'
+        rng = np.random.RandomState(1994)
+        features = rng.randn(10, 10)
+        labels = rng.randn(10)
+        with tb.captured_output() as (out, _err):
+            # Expect a regular XGBoostError rather than a thrust exception.
+            with pytest.raises(xgb.core.XGBoostError):
+                xgb.train({'tree_method': 'gpu_hist'}, xgb.DMatrix(features, labels))
+
+            assert 'No visible GPU is found' in out.getvalue()
diff --git a/tests/python-gpu/test_gpu_pickling.py b/tests/python-gpu/test_gpu_pickling.py
index b8cc56203..9fe12ffbf 100644
--- a/tests/python-gpu/test_gpu_pickling.py
+++ b/tests/python-gpu/test_gpu_pickling.py
@@ -157,3 +157,12 @@ class TestPickling(unittest.TestCase):
         bst.set_param({'predictor': 'cpu_predictor'})
         cpu_pred = model.predict(x, output_margin=True)
         np.testing.assert_allclose(cpu_pred, gpu_pred, rtol=1e-5)
+
+    def test_training_on_cpu_only_env(self):
+        # Run the load_pickle test in a child process with CUDA_VISIBLE_DEVICES=-1.
+        env = dict(os.environ, CUDA_VISIBLE_DEVICES='-1')
+        args = self.args_template.copy()
+        args.append(
+            "./tests/python-gpu/"
+            "load_pickle.py::TestLoadPickle::test_training_on_cpu_only_env")
+        assert subprocess.call(args, env=env) == 0