Fix approximated predict contribution. (#6811)

This commit is contained in:
Jiaming Yuan
2021-04-03 02:15:03 +08:00
committed by GitHub
parent 0cced530ea
commit 7e06c81894
6 changed files with 47 additions and 17 deletions

View File

@@ -98,6 +98,27 @@ def test_predict_shape():
assert len(contrib.shape) == 3
assert contrib.shape[1] == 1
contrib = reg.get_booster().predict(
xgb.DMatrix(X), pred_contribs=True, approx_contribs=True
)
assert len(contrib.shape) == 2
assert contrib.shape[1] == X.shape[1] + 1
interaction = reg.get_booster().predict(
xgb.DMatrix(X), pred_interactions=True, approx_contribs=True
)
assert len(interaction.shape) == 3
assert interaction.shape[1] == X.shape[1] + 1
assert interaction.shape[2] == X.shape[1] + 1
interaction = reg.get_booster().predict(
xgb.DMatrix(X), pred_interactions=True, approx_contribs=True, strict_shape=True
)
assert len(interaction.shape) == 4
assert interaction.shape[1] == 1
assert interaction.shape[2] == X.shape[1] + 1
assert interaction.shape[3] == X.shape[1] + 1
class TestInplacePredict:
'''Tests for running inplace prediction'''