Define the new device parameter. (#9362)
This commit is contained in:
@@ -34,7 +34,7 @@ class TestLoadPickle:
|
||||
bst = load_pickle(model_path)
|
||||
config = bst.save_config()
|
||||
config = json.loads(config)
|
||||
assert config["learner"]["generic_param"]["gpu_id"] == "-1"
|
||||
assert config["learner"]["generic_param"]["device"] == "cpu"
|
||||
|
||||
def test_context_is_preserved(self) -> None:
|
||||
"""Test the device context is preserved after pickling."""
|
||||
@@ -42,14 +42,14 @@ class TestLoadPickle:
|
||||
bst = load_pickle(model_path)
|
||||
config = bst.save_config()
|
||||
config = json.loads(config)
|
||||
assert config["learner"]["generic_param"]["gpu_id"] == "0"
|
||||
assert config["learner"]["generic_param"]["device"] == "cuda:0"
|
||||
|
||||
def test_wrap_gpu_id(self) -> None:
|
||||
assert os.environ["CUDA_VISIBLE_DEVICES"] == "0"
|
||||
bst = load_pickle(model_path)
|
||||
config = bst.save_config()
|
||||
config = json.loads(config)
|
||||
assert config["learner"]["generic_param"]["gpu_id"] == "0"
|
||||
assert config["learner"]["generic_param"]["device"] == "cuda:0"
|
||||
|
||||
x, y = build_dataset()
|
||||
test_x = xgb.DMatrix(x)
|
||||
|
||||
Reference in New Issue
Block a user