Fix wrapping GPU ID and prevent data copying. (#5160)

* Removed some data copying.

* Make sure gpu_id is valid before any configuration is carried out.
This commit is contained in:
Jiaming Yuan
2019-12-27 16:51:08 +08:00
committed by GitHub
parent ee81ba8e1f
commit 61286c6e8f
7 changed files with 55 additions and 17 deletions

View File

@@ -37,3 +37,15 @@ class TestLoadPickle(unittest.TestCase):
config = json.loads(config)
assert config['learner']['gradient_booster']['gbtree_train_param'][
'predictor'] == 'gpu_predictor'
def test_wrap_gpu_id(self):
assert os.environ['CUDA_VISIBLE_DEVICES'] == '0'
bst = load_pickle(model_path)
config = bst.save_config()
config = json.loads(config)
assert config['learner']['generic_param']['gpu_id'] == '0'
x, y = build_dataset()
test_x = xgb.DMatrix(x)
res = bst.predict(test_x)
assert len(res) == 10