Enhance inplace prediction. (#6653)

* Accept array interface for csr and array.
* Accept an optional proxy dmatrix for metainfo.

This constructs an explicit `_ProxyDMatrix` type in Python.

* Remove unused doc.
* Add strict output.
This commit is contained in:
Jiaming Yuan
2021-02-02 11:41:46 +08:00
committed by GitHub
parent 87ab1ad607
commit 411592a347
22 changed files with 955 additions and 530 deletions

View File

@@ -80,20 +80,28 @@ def test_predict_leaf():
class TestInplacePredict:
'''Tests for running inplace prediction'''
@classmethod
def setup_class(cls):
cls.rows = 100
cls.cols = 10
cls.rng = np.random.RandomState(1994)
cls.X = cls.rng.randn(cls.rows, cls.cols)
cls.y = cls.rng.randn(cls.rows)
dtrain = xgb.DMatrix(cls.X, cls.y)
cls.booster = xgb.train({'tree_method': 'hist'},
dtrain, num_boost_round=10)
cls.test = xgb.DMatrix(cls.X[:10, ...])
def test_predict(self):
rows = 1000
cols = 10
booster = self.booster
X = self.X
test = self.test
np.random.seed(1994)
X = np.random.randn(rows, cols)
y = np.random.randn(rows)
dtrain = xgb.DMatrix(X, y)
booster = xgb.train({'tree_method': 'hist'},
dtrain, num_boost_round=10)
test = xgb.DMatrix(X[:10, ...])
predt_from_array = booster.inplace_predict(X[:10, ...])
predt_from_dmatrix = booster.predict(test)
@@ -111,7 +119,7 @@ class TestInplacePredict:
return np.all(copied_predt == inplace_predt)
for i in range(10):
run_threaded_predict(X, rows, predict_dense)
run_threaded_predict(X, self.rows, predict_dense)
def predict_csr(x):
inplace_predt = booster.inplace_predict(sparse.csr_matrix(x))
@@ -120,4 +128,14 @@ class TestInplacePredict:
return np.all(copied_predt == inplace_predt)
for i in range(10):
run_threaded_predict(X, rows, predict_csr)
run_threaded_predict(X, self.rows, predict_csr)
def test_base_margin(self):
booster = self.booster
base_margin = self.rng.randn(self.rows)
from_inplace = booster.inplace_predict(data=self.X, base_margin=base_margin)
dtrain = xgb.DMatrix(self.X, self.y, base_margin=base_margin)
from_dmatrix = booster.predict(dtrain)
np.testing.assert_allclose(from_dmatrix, from_inplace)