Fix approximated predict contribution. (#6811)
This commit is contained in:
@@ -98,6 +98,27 @@ def test_predict_shape():
|
||||
assert len(contrib.shape) == 3
|
||||
assert contrib.shape[1] == 1
|
||||
|
||||
contrib = reg.get_booster().predict(
|
||||
xgb.DMatrix(X), pred_contribs=True, approx_contribs=True
|
||||
)
|
||||
assert len(contrib.shape) == 2
|
||||
assert contrib.shape[1] == X.shape[1] + 1
|
||||
|
||||
interaction = reg.get_booster().predict(
|
||||
xgb.DMatrix(X), pred_interactions=True, approx_contribs=True
|
||||
)
|
||||
assert len(interaction.shape) == 3
|
||||
assert interaction.shape[1] == X.shape[1] + 1
|
||||
assert interaction.shape[2] == X.shape[1] + 1
|
||||
|
||||
interaction = reg.get_booster().predict(
|
||||
xgb.DMatrix(X), pred_interactions=True, approx_contribs=True, strict_shape=True
|
||||
)
|
||||
assert len(interaction.shape) == 4
|
||||
assert interaction.shape[1] == 1
|
||||
assert interaction.shape[2] == X.shape[1] + 1
|
||||
assert interaction.shape[3] == X.shape[1] + 1
|
||||
|
||||
|
||||
class TestInplacePredict:
|
||||
'''Tests for running inplace prediction'''
|
||||
|
||||
Reference in New Issue
Block a user