Define the new device parameter. (#9362)

This commit is contained in:
Jiaming Yuan
2023-07-13 19:30:25 +08:00
committed by GitHub
parent 2d0cd2817e
commit 04aff3af8e
63 changed files with 827 additions and 477 deletions

View File

@@ -34,7 +34,7 @@ class TestLoadPickle:
bst = load_pickle(model_path)
config = bst.save_config()
config = json.loads(config)
assert config["learner"]["generic_param"]["gpu_id"] == "-1"
assert config["learner"]["generic_param"]["device"] == "cpu"
def test_context_is_preserved(self) -> None:
"""Test the device context is preserved after pickling."""
@@ -42,14 +42,14 @@ class TestLoadPickle:
bst = load_pickle(model_path)
config = bst.save_config()
config = json.loads(config)
assert config["learner"]["generic_param"]["gpu_id"] == "0"
assert config["learner"]["generic_param"]["device"] == "cuda:0"
def test_wrap_gpu_id(self) -> None:
assert os.environ["CUDA_VISIBLE_DEVICES"] == "0"
bst = load_pickle(model_path)
config = bst.save_config()
config = json.loads(config)
assert config["learner"]["generic_param"]["gpu_id"] == "0"
assert config["learner"]["generic_param"]["device"] == "cuda:0"
x, y = build_dataset()
test_x = xgb.DMatrix(x)