Fix wrapping GPU ID and prevent data copying. (#5160)
* Removed some data copying. * Make sure gpu_id is valid before any configuration is carried out.
This commit is contained in:
@@ -37,3 +37,15 @@ class TestLoadPickle(unittest.TestCase):
|
||||
config = json.loads(config)
|
||||
assert config['learner']['gradient_booster']['gbtree_train_param'][
|
||||
'predictor'] == 'gpu_predictor'
|
||||
|
||||
def test_wrap_gpu_id(self):
|
||||
assert os.environ['CUDA_VISIBLE_DEVICES'] == '0'
|
||||
bst = load_pickle(model_path)
|
||||
config = bst.save_config()
|
||||
config = json.loads(config)
|
||||
assert config['learner']['generic_param']['gpu_id'] == '0'
|
||||
|
||||
x, y = build_dataset()
|
||||
test_x = xgb.DMatrix(x)
|
||||
res = bst.predict(test_x)
|
||||
assert len(res) == 10
|
||||
|
||||
Reference in New Issue
Block a user