Define the new device parameter. (#9362)
This commit is contained in:
@@ -43,10 +43,16 @@ class TestGPUEvalMetrics:
|
||||
num_boost_round=10,
|
||||
)
|
||||
cpu_auc = float(booster.eval(Xy).split(":")[1])
|
||||
booster.set_param({"gpu_id": "0"})
|
||||
assert json.loads(booster.save_config())["learner"]["generic_param"]["gpu_id"] == "0"
|
||||
booster.set_param({"device": "cuda:0"})
|
||||
assert (
|
||||
json.loads(booster.save_config())["learner"]["generic_param"]["device"]
|
||||
== "cuda:0"
|
||||
)
|
||||
gpu_auc = float(booster.eval(Xy).split(":")[1])
|
||||
assert json.loads(booster.save_config())["learner"]["generic_param"]["gpu_id"] == "0"
|
||||
assert (
|
||||
json.loads(booster.save_config())["learner"]["generic_param"]["device"]
|
||||
== "cuda:0"
|
||||
)
|
||||
|
||||
np.testing.assert_allclose(cpu_auc, gpu_auc)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user