Define the new device parameter. (#9362)
This commit is contained in:
@@ -28,7 +28,7 @@ def run_threaded_predict(X, rows, predict_func):
|
||||
assert f.result()
|
||||
|
||||
|
||||
def run_predict_leaf(gpu_id: int) -> np.ndarray:
|
||||
def run_predict_leaf(device: str) -> np.ndarray:
|
||||
rows = 100
|
||||
cols = 4
|
||||
classes = 5
|
||||
@@ -48,7 +48,7 @@ def run_predict_leaf(gpu_id: int) -> np.ndarray:
|
||||
num_boost_round=num_boost_round,
|
||||
)
|
||||
|
||||
booster = tm.set_ordinal(gpu_id, booster)
|
||||
booster.set_param({"device": device})
|
||||
empty = xgb.DMatrix(np.ones(shape=(0, cols)))
|
||||
empty_leaf = booster.predict(empty, pred_leaf=True)
|
||||
assert empty_leaf.shape[0] == 0
|
||||
@@ -74,14 +74,14 @@ def run_predict_leaf(gpu_id: int) -> np.ndarray:
|
||||
|
||||
# When there's only 1 tree, the output is a 1 dim vector
|
||||
booster = xgb.train({"tree_method": "hist"}, num_boost_round=1, dtrain=m)
|
||||
booster = tm.set_ordinal(gpu_id, booster)
|
||||
booster.set_param({"device": device})
|
||||
assert booster.predict(m, pred_leaf=True).shape == (rows,)
|
||||
|
||||
return leaf
|
||||
|
||||
|
||||
def test_predict_leaf() -> None:
|
||||
run_predict_leaf(-1)
|
||||
run_predict_leaf("cpu")
|
||||
|
||||
|
||||
def test_predict_shape():
|
||||
|
||||
Reference in New Issue
Block a user