Fix approximated predict contribution. (#6811)

This commit is contained in:
Jiaming Yuan 2021-04-03 02:15:03 +08:00 committed by GitHub
parent 0cced530ea
commit 7e06c81894
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 47 additions and 17 deletions

View File

@ -740,15 +740,17 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
* *
* \param handle Booster handle * \param handle Booster handle
* \param dmat DMatrix handle * \param dmat DMatrix handle
* \param c_json_config String encoded predict configuration in JSON format. * \param c_json_config String encoded predict configuration in JSON format, with
* following available fields in the JSON object:
* *
* "type": [0, 5] * "type": [0, 6]
* 0: normal prediction * 0: normal prediction
* 1: output margin * 1: output margin
* 2: predict contribution * 2: predict contribution
* 3: predict approxmated contribution * 3: predict approximated contribution
* 4: predict feature interaction * 4: predict feature interaction
* 5: predict leaf * 5: predict approximated feature interaction
* 6: predict leaf
* "training": bool * "training": bool
* Whether the prediction function is used as part of a training loop. **Not used * Whether the prediction function is used as part of a training loop. **Not used
* for inplace prediction**. * for inplace prediction**.
@ -764,7 +766,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
* "iteration_begin": int * "iteration_begin": int
* Beginning iteration of prediction. * Beginning iteration of prediction.
* "iteration_end": int * "iteration_end": int
* End iteration of prediction. Set to 0 this will become the size of tree model. * End iteration of prediction. Set to 0 this will become the size of tree model (all the trees).
* "strict_shape": bool * "strict_shape": bool
* Whether we should reshape the output with stricter rules. If set to true, * Whether we should reshape the output with stricter rules. If set to true,
* normal/margin/contrib/interaction predict will output consistent shape * normal/margin/contrib/interaction predict will output consistent shape

View File

@ -36,7 +36,8 @@ enum class PredictionType : std::uint8_t { // NOLINT
kContribution = 2, kContribution = 2,
kApproxContribution = 3, kApproxContribution = 3,
kInteraction = 4, kInteraction = 4,
kLeaf = 5 kApproxInteraction = 5,
kLeaf = 6
}; };
/*! \brief entry to easily hold returning information */ /*! \brief entry to easily hold returning information */

View File

@ -1647,7 +1647,9 @@ class Booster(object):
prediction. Note the final column is the bias term. prediction. Note the final column is the bias term.
approx_contribs : approx_contribs :
Approximate the contributions of each feature Approximate the contributions of each feature. Used when ``pred_contribs`` or
``pred_interactions`` is set to True. Changing the default of this parameter
(False) is not recommended.
pred_interactions : pred_interactions :
When this is True the output will be a matrix of size (nsample, When this is True the output will be a matrix of size (nsample,
@ -1716,9 +1718,9 @@ class Booster(object):
if pred_contribs: if pred_contribs:
assign_type(2 if not approx_contribs else 3) assign_type(2 if not approx_contribs else 3)
if pred_interactions: if pred_interactions:
assign_type(4) assign_type(4 if not approx_contribs else 5)
if pred_leaf: if pred_leaf:
assign_type(5) assign_type(6)
preds = ctypes.POINTER(ctypes.c_float)() preds = ctypes.POINTER(ctypes.c_float)()
shape = ctypes.POINTER(c_bst_ulong)() shape = ctypes.POINTER(c_bst_ulong)()
dims = c_bst_ulong() dims = c_bst_ulong()

View File

@ -651,13 +651,17 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
auto type = PredictionType(get<Integer const>(config["type"])); auto type = PredictionType(get<Integer const>(config["type"]));
auto iteration_begin = get<Integer const>(config["iteration_begin"]); auto iteration_begin = get<Integer const>(config["iteration_begin"]);
auto iteration_end = get<Integer const>(config["iteration_end"]); auto iteration_end = get<Integer const>(config["iteration_end"]);
learner->Predict( bool approximate = type == PredictionType::kApproxContribution ||
*static_cast<std::shared_ptr<DMatrix> *>(dmat), type == PredictionType::kApproxInteraction;
type == PredictionType::kMargin, &entry.predictions, iteration_begin, bool contribs = type == PredictionType::kContribution ||
iteration_end, get<Boolean const>(config["training"]), type == PredictionType::kApproxContribution;
type == PredictionType::kLeaf, type == PredictionType::kContribution, bool interactions = type == PredictionType::kInteraction ||
type == PredictionType::kApproxContribution, type == PredictionType::kApproxInteraction;
type == PredictionType::kInteraction); bool training = get<Boolean const>(config["training"]);
learner->Predict(p_m, type == PredictionType::kMargin, &entry.predictions,
iteration_begin, iteration_end, training,
type == PredictionType::kLeaf, contribs, approximate,
interactions);
*out_result = dmlc::BeginPtr(entry.predictions.ConstHostVector()); *out_result = dmlc::BeginPtr(entry.predictions.ConstHostVector());
auto &shape = learner->GetThreadLocal().prediction_shape; auto &shape = learner->GetThreadLocal().prediction_shape;
auto chunksize = p_m->Info().num_row_ == 0 ? 0 : entry.predictions.Size() / p_m->Info().num_row_; auto chunksize = p_m->Info().num_row_ == 0 ? 0 : entry.predictions.Size() / p_m->Info().num_row_;

View File

@ -56,7 +56,6 @@ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows
} }
case PredictionType::kApproxContribution: case PredictionType::kApproxContribution:
case PredictionType::kContribution: { case PredictionType::kContribution: {
auto groups = chunksize / (cols + 1);
if (groups == 1 && !strict_shape) { if (groups == 1 && !strict_shape) {
*out_dim = 2; *out_dim = 2;
shape.resize(*out_dim); shape.resize(*out_dim);
@ -71,6 +70,7 @@ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows
} }
break; break;
} }
case PredictionType::kApproxInteraction:
case PredictionType::kInteraction: { case PredictionType::kInteraction: {
if (groups == 1 && !strict_shape) { if (groups == 1 && !strict_shape) {
*out_dim = 3; *out_dim = 3;

View File

@ -98,6 +98,27 @@ def test_predict_shape():
assert len(contrib.shape) == 3 assert len(contrib.shape) == 3
assert contrib.shape[1] == 1 assert contrib.shape[1] == 1
contrib = reg.get_booster().predict(
xgb.DMatrix(X), pred_contribs=True, approx_contribs=True
)
assert len(contrib.shape) == 2
assert contrib.shape[1] == X.shape[1] + 1
interaction = reg.get_booster().predict(
xgb.DMatrix(X), pred_interactions=True, approx_contribs=True
)
assert len(interaction.shape) == 3
assert interaction.shape[1] == X.shape[1] + 1
assert interaction.shape[2] == X.shape[1] + 1
interaction = reg.get_booster().predict(
xgb.DMatrix(X), pred_interactions=True, approx_contribs=True, strict_shape=True
)
assert len(interaction.shape) == 4
assert interaction.shape[1] == 1
assert interaction.shape[2] == X.shape[1] + 1
assert interaction.shape[3] == X.shape[1] + 1
class TestInplacePredict: class TestInplacePredict:
'''Tests for running inplace prediction''' '''Tests for running inplace prediction'''