Add SHAP interaction effects, fix minor bug, and add cox loss (#3043)

* Add interaction effects and cox loss

* Minimize whitespace changes

* Cox loss no longer needs a pre-sorted dataset.

* Address code review comments

* Remove mem check, rename to pred_interactions, include bias

* Make lint happy

* More lint fixes

* Fix cox loss indexing

* Fix main effects and tests

* Fix lint

* Use half interaction values on the off-diagonals

* Fix lint again
This commit is contained in:
Scott Lundberg
2018-02-07 18:38:01 -08:00
committed by Vadim Khotilovich
parent 077abb35cd
commit d878c36c84
19 changed files with 638 additions and 125 deletions

View File

@@ -304,6 +304,52 @@ struct EvalMAP : public EvalRankList {
}
};
/*!
 * \brief Cox: negative log partial likelihood of the Cox proportional hazards model.
 *
 * Rows are visited in the order returned by MetaInfo::LabelAbsSort(). A row with a
 * positive label contributes a term to the result; for every row, |label| is compared
 * with the next row's |label| to decide when the running denominator may shrink.
 * NOTE(review): presumably |label| is the survival time and a non-positive label marks
 * a censored row - confirm against the Cox objective.
 */
struct EvalCox : public Metric {
 public:
  EvalCox() {}

  /*!
   * \brief Evaluate the metric on a set of predictions.
   * \param preds predictions, indexed by row. Each value is added to the denominator
   *        as-is and passed to log(), so it is presumably an already-exponentiated
   *        hazard ratio (TODO confirm with the objective's prediction transform).
   * \param info meta information; labels and their |label| ordering are read from it.
   * \param distributed must be false, distributed evaluation is rejected.
   * \return sum over positive-label rows of log(denominator) - log(pred), divided by the
   *         number of such rows. When no row has a positive label the division is 0/0,
   *         which yields NaN on IEEE-754 platforms.
   */
  bst_float Eval(const std::vector<bst_float> &preds,
                 const MetaInfo &info,
                 bool distributed) const override {
    CHECK(!distributed) << "Cox metric does not support distributed evaluation";

    const bst_omp_uint ndata = static_cast<bst_omp_uint>(info.labels.size());
    // every row index below ndata is read from preds, so a shorter prediction
    // vector would be read out of bounds
    CHECK(preds.size() >= ndata) << "Cox metric needs one prediction per label";
    const std::vector<std::size_t> &label_order = info.LabelAbsSort();

    // pre-compute a sum for the denominator
    double exp_p_sum = 0;  // we use double because we might need the precision with large datasets
    for (bst_omp_uint i = 0; i < ndata; ++i) {
      exp_p_sum += preds[i];
    }

    double out = 0;
    double accumulated_sum = 0;  // predictions of rows that share the current |label|
    bst_omp_uint num_events = 0;
    for (bst_omp_uint i = 0; i < ndata; ++i) {
      const std::size_t ind = label_order[i];
      const auto label = info.labels[ind];
      if (label > 0) {
        out -= std::log(preds[ind]) - std::log(exp_p_sum);
        ++num_events;
      }

      // only update the denominator after we move forward in time (labels are sorted);
      // rows tied on |label| therefore all see the same denominator
      accumulated_sum += preds[ind];
      if (i == ndata - 1 || std::abs(label) < std::abs(info.labels[label_order[i + 1]])) {
        exp_p_sum -= accumulated_sum;
        accumulated_sum = 0;
      }
    }

    return out / num_events;  // normalize by the number of events
  }

  const char* Name() const override {
    return "cox-nloglik";
  }
};
// Register the AMS metric under the name "ams"; the factory forwards the
// parameter string to the EvalAMS constructor.
XGBOOST_REGISTER_METRIC(AMS, "ams")
.describe("AMS metric for higgs.")
.set_body([](const char* param) { return new EvalAMS(param); });
@@ -323,5 +369,9 @@ XGBOOST_REGISTER_METRIC(NDCG, "ndcg")
// Register the MAP ranking metric under the name "map"; the factory forwards
// the parameter string to the EvalMAP constructor.
XGBOOST_REGISTER_METRIC(MAP, "map")
.describe("map@k for rank.")
.set_body([](const char* param) { return new EvalMAP(param); });
// Register the Cox metric under the name "cox-nloglik". EvalCox takes no
// configuration, so the factory ignores the parameter string.
XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik")
  .describe("Negative log partial likelihood of Cox proportional hazards model.")
  .set_body([](const char*) { return new EvalCox(); });
} // namespace metric
} // namespace xgboost