Enhance inplace prediction. (#6653)
* Accept array interface for csr and array. * Accept an optional proxy dmatrix for metainfo. This constructs an explicit `_ProxyDMatrix` type in Python. * Remove unused doc. * Add strict output.
This commit is contained in:
@@ -11,8 +11,7 @@ import testing as tm
|
||||
class TestDeviceQuantileDMatrix:
|
||||
def test_dmatrix_numpy_init(self):
|
||||
data = np.random.randn(5, 5)
|
||||
with pytest.raises(TypeError,
|
||||
match='is not supported for DeviceQuantileDMatrix'):
|
||||
with pytest.raises(TypeError, match='is not supported'):
|
||||
xgb.DeviceQuantileDMatrix(data, np.ones(5, dtype=np.float64))
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
|
||||
@@ -141,6 +141,13 @@ class TestGPUPredict:
|
||||
assert np.allclose(cpu_train_score, gpu_train_score)
|
||||
assert np.allclose(cpu_test_score, gpu_test_score)
|
||||
|
||||
def run_inplace_base_margin(self, booster, dtrain, X, base_margin):
|
||||
import cupy as cp
|
||||
dtrain.set_info(base_margin=base_margin)
|
||||
from_inplace = booster.inplace_predict(data=X, base_margin=base_margin)
|
||||
from_dmatrix = booster.predict(dtrain)
|
||||
cp.testing.assert_allclose(from_inplace, from_dmatrix)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cupy())
|
||||
def test_inplace_predict_cupy(self):
|
||||
import cupy as cp
|
||||
@@ -175,6 +182,9 @@ class TestGPUPredict:
|
||||
for i in range(10):
|
||||
run_threaded_predict(X, rows, predict_dense)
|
||||
|
||||
base_margin = cp_rng.randn(rows)
|
||||
self.run_inplace_base_margin(booster, dtrain, X, base_margin)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_cudf())
|
||||
def test_inplace_predict_cudf(self):
|
||||
import cupy as cp
|
||||
@@ -208,6 +218,9 @@ class TestGPUPredict:
|
||||
for i in range(10):
|
||||
run_threaded_predict(X, rows, predict_df)
|
||||
|
||||
base_margin = cudf.Series(rng.randn(rows))
|
||||
self.run_inplace_base_margin(booster, dtrain, X, base_margin)
|
||||
|
||||
@given(strategies.integers(1, 10),
|
||||
tm.dataset_strategy, shap_parameter_strategy)
|
||||
@settings(deadline=None)
|
||||
|
||||
Reference in New Issue
Block a user