Fix dart inplace prediction with GPU input. (#6777)

* Fix dart inplace predict with data on GPU, which might trigger a fatal check
for device access right.
* Avoid copying data whenever possible.
This commit is contained in:
Jiaming Yuan
2021-03-25 12:00:32 +08:00
committed by GitHub
parent 1d90577800
commit a7083d3c13
6 changed files with 135 additions and 25 deletions

View File

@@ -312,3 +312,33 @@ class TestGPUPredict:
pred = bst.predict(dtrain)
rmse = mean_squared_error(y_true=y, y_pred=pred, squared=False)
np.testing.assert_almost_equal(rmse, eval_history['train']['rmse'][-1], decimal=5)
def test_predict_dart(self):
import cupy as cp
rng = cp.random.RandomState(1994)
n_samples = 1000
X = rng.randn(n_samples, 10)
y = rng.randn(n_samples)
Xy = xgb.DMatrix(X, y)
booster = xgb.train(
{
"tree_method": "gpu_hist",
"booster": "dart",
"rate_drop": 0.5,
},
Xy,
num_boost_round=32
)
# predictor=auto
inplace = booster.inplace_predict(X)
copied = booster.predict(Xy)
copied = cp.array(copied)
cp.testing.assert_allclose(inplace, copied, atol=1e-6)
booster.set_param({"predictor": "gpu_predictor"})
inplace = booster.inplace_predict(X)
copied = booster.predict(Xy)
copied = cp.array(copied)
cp.testing.assert_allclose(inplace, copied, atol=1e-6)