Fix approximated predict contribution. (#6811)
This commit is contained in:
parent
0cced530ea
commit
7e06c81894
@ -740,15 +740,17 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
*
|
||||
* \param handle Booster handle
|
||||
* \param dmat DMatrix handle
|
||||
* \param c_json_config String encoded predict configuration in JSON format.
|
||||
* \param c_json_config String encoded predict configuration in JSON format, with
|
||||
* following available fields in the JSON object:
|
||||
*
|
||||
* "type": [0, 5]
|
||||
* "type": [0, 6]
|
||||
* 0: normal prediction
|
||||
* 1: output margin
|
||||
* 2: predict contribution
|
||||
* 3: predict approxmated contribution
|
||||
* 3: predict approximated contribution
|
||||
* 4: predict feature interaction
|
||||
* 5: predict leaf
|
||||
* 5: predict approximated feature interaction
|
||||
* 6: predict leaf
|
||||
* "training": bool
|
||||
* Whether the prediction function is used as part of a training loop. **Not used
|
||||
* for inplace prediction**.
|
||||
@ -764,7 +766,7 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
|
||||
* "iteration_begin": int
|
||||
* Beginning iteration of prediction.
|
||||
* "iteration_end": int
|
||||
* End iteration of prediction. Set to 0 this will become the size of tree model.
|
||||
* End iteration of prediction. Set to 0 this will become the size of tree model (all the trees).
|
||||
* "strict_shape": bool
|
||||
* Whether should we reshape the output with stricter rules. If set to true,
|
||||
* normal/margin/contrib/interaction predict will output consistent shape
|
||||
|
||||
@ -36,7 +36,8 @@ enum class PredictionType : std::uint8_t { // NOLINT
|
||||
kContribution = 2,
|
||||
kApproxContribution = 3,
|
||||
kInteraction = 4,
|
||||
kLeaf = 5
|
||||
kApproxInteraction = 5,
|
||||
kLeaf = 6
|
||||
};
|
||||
|
||||
/*! \brief entry to to easily hold returning information */
|
||||
|
||||
@ -1647,7 +1647,9 @@ class Booster(object):
|
||||
prediction. Note the final column is the bias term.
|
||||
|
||||
approx_contribs :
|
||||
Approximate the contributions of each feature
|
||||
Approximate the contributions of each feature. Used when ``pred_contribs`` or
|
||||
``pred_interactions`` is set to True. Changing the default of this parameter
|
||||
(False) is not recommended.
|
||||
|
||||
pred_interactions :
|
||||
When this is True the output will be a matrix of size (nsample,
|
||||
@ -1716,9 +1718,9 @@ class Booster(object):
|
||||
if pred_contribs:
|
||||
assign_type(2 if not approx_contribs else 3)
|
||||
if pred_interactions:
|
||||
assign_type(4)
|
||||
assign_type(4 if not approx_contribs else 5)
|
||||
if pred_leaf:
|
||||
assign_type(5)
|
||||
assign_type(6)
|
||||
preds = ctypes.POINTER(ctypes.c_float)()
|
||||
shape = ctypes.POINTER(c_bst_ulong)()
|
||||
dims = c_bst_ulong()
|
||||
|
||||
@ -651,13 +651,17 @@ XGB_DLL int XGBoosterPredictFromDMatrix(BoosterHandle handle,
|
||||
auto type = PredictionType(get<Integer const>(config["type"]));
|
||||
auto iteration_begin = get<Integer const>(config["iteration_begin"]);
|
||||
auto iteration_end = get<Integer const>(config["iteration_end"]);
|
||||
learner->Predict(
|
||||
*static_cast<std::shared_ptr<DMatrix> *>(dmat),
|
||||
type == PredictionType::kMargin, &entry.predictions, iteration_begin,
|
||||
iteration_end, get<Boolean const>(config["training"]),
|
||||
type == PredictionType::kLeaf, type == PredictionType::kContribution,
|
||||
type == PredictionType::kApproxContribution,
|
||||
type == PredictionType::kInteraction);
|
||||
bool approximate = type == PredictionType::kApproxContribution ||
|
||||
type == PredictionType::kApproxInteraction;
|
||||
bool contribs = type == PredictionType::kContribution ||
|
||||
type == PredictionType::kApproxContribution;
|
||||
bool interactions = type == PredictionType::kInteraction ||
|
||||
type == PredictionType::kApproxInteraction;
|
||||
bool training = get<Boolean const>(config["training"]);
|
||||
learner->Predict(p_m, type == PredictionType::kMargin, &entry.predictions,
|
||||
iteration_begin, iteration_end, training,
|
||||
type == PredictionType::kLeaf, contribs, approximate,
|
||||
interactions);
|
||||
*out_result = dmlc::BeginPtr(entry.predictions.ConstHostVector());
|
||||
auto &shape = learner->GetThreadLocal().prediction_shape;
|
||||
auto chunksize = p_m->Info().num_row_ == 0 ? 0 : entry.predictions.Size() / p_m->Info().num_row_;
|
||||
|
||||
@ -56,7 +56,6 @@ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows
|
||||
}
|
||||
case PredictionType::kApproxContribution:
|
||||
case PredictionType::kContribution: {
|
||||
auto groups = chunksize / (cols + 1);
|
||||
if (groups == 1 && !strict_shape) {
|
||||
*out_dim = 2;
|
||||
shape.resize(*out_dim);
|
||||
@ -71,6 +70,7 @@ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows
|
||||
}
|
||||
break;
|
||||
}
|
||||
case PredictionType::kApproxInteraction:
|
||||
case PredictionType::kInteraction: {
|
||||
if (groups == 1 && !strict_shape) {
|
||||
*out_dim = 3;
|
||||
|
||||
@ -98,6 +98,27 @@ def test_predict_shape():
|
||||
assert len(contrib.shape) == 3
|
||||
assert contrib.shape[1] == 1
|
||||
|
||||
contrib = reg.get_booster().predict(
|
||||
xgb.DMatrix(X), pred_contribs=True, approx_contribs=True
|
||||
)
|
||||
assert len(contrib.shape) == 2
|
||||
assert contrib.shape[1] == X.shape[1] + 1
|
||||
|
||||
interaction = reg.get_booster().predict(
|
||||
xgb.DMatrix(X), pred_interactions=True, approx_contribs=True
|
||||
)
|
||||
assert len(interaction.shape) == 3
|
||||
assert interaction.shape[1] == X.shape[1] + 1
|
||||
assert interaction.shape[2] == X.shape[1] + 1
|
||||
|
||||
interaction = reg.get_booster().predict(
|
||||
xgb.DMatrix(X), pred_interactions=True, approx_contribs=True, strict_shape=True
|
||||
)
|
||||
assert len(interaction.shape) == 4
|
||||
assert interaction.shape[1] == 1
|
||||
assert interaction.shape[2] == X.shape[1] + 1
|
||||
assert interaction.shape[3] == X.shape[1] + 1
|
||||
|
||||
|
||||
class TestInplacePredict:
|
||||
'''Tests for running inplace prediction'''
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user