Define the new device parameter. (#9362)
This commit is contained in:
@@ -113,14 +113,6 @@ class TestPickling:
|
||||
param = {"tree_method": "gpu_hist", "verbosity": 1}
|
||||
bst = xgb.train(param, train_x)
|
||||
|
||||
with tm.captured_output() as (out, err):
|
||||
bst.inplace_predict(x)
|
||||
|
||||
# The warning is redirected to Python callback, so it's printed in stdout
|
||||
# instead of stderr.
|
||||
stdout = out.getvalue()
|
||||
assert stdout.find("mismatched devices") != -1
|
||||
|
||||
save_pickle(bst, model_path)
|
||||
|
||||
args = self.args_template.copy()
|
||||
@@ -177,7 +169,7 @@ class TestPickling:
|
||||
|
||||
# Switch to CPU predictor
|
||||
bst = model.get_booster()
|
||||
tm.set_ordinal(-1, bst)
|
||||
bst.set_param({"device": "cpu"})
|
||||
cpu_pred = model.predict(x, output_margin=True)
|
||||
np.testing.assert_allclose(cpu_pred, gpu_pred, rtol=1e-5)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user