Add SHAP interaction effects, fix minor bug, and add cox loss (#3043)
* Add interaction effects and cox loss
* Minimize whitespace changes
* Cox loss now no longer needs a pre-sorted dataset.
* Address code review comments
* Remove mem check, rename to pred_interactions, include bias
* Make lint happy
* More lint fixes
* Fix cox loss indexing
* Fix main effects and tests
* Fix lint
* Use half interaction values on the off-diagonals
* Fix lint again
This commit is contained in:
committed by
Vadim Khotilovich
parent
077abb35cd
commit
d878c36c84
@@ -197,6 +197,90 @@ XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson")
|
||||
.describe("Possion regression for count data.")
|
||||
.set_body([]() { return new PoissonRegression(); });
|
||||
|
||||
// cox regression for survival data (negative values mean they are censored)
|
||||
class CoxRegression : public ObjFunction {
|
||||
public:
|
||||
// declare functions
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {}
|
||||
void GetGradient(const std::vector<bst_float> &preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
std::vector<bst_gpair> *out_gpair) override {
|
||||
CHECK_NE(info.labels.size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.size(), info.labels.size()) << "labels are not correctly provided";
|
||||
out_gpair->resize(preds.size());
|
||||
const std::vector<size_t> &label_order = info.LabelAbsSort();
|
||||
|
||||
const omp_ulong ndata = static_cast<omp_ulong>(preds.size()); // NOLINT(*)
|
||||
|
||||
// pre-compute a sum
|
||||
double exp_p_sum = 0; // we use double because we might need the precision with large datasets
|
||||
for (omp_ulong i = 0; i < ndata; ++i) {
|
||||
exp_p_sum += std::exp(preds[label_order[i]]);
|
||||
}
|
||||
|
||||
// start calculating grad and hess
|
||||
double r_k = 0;
|
||||
double s_k = 0;
|
||||
double last_exp_p = 0.0;
|
||||
double last_abs_y = 0.0;
|
||||
double accumulated_sum = 0;
|
||||
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
|
||||
const size_t ind = label_order[i];
|
||||
const double p = preds[ind];
|
||||
const double exp_p = std::exp(p);
|
||||
const double w = info.GetWeight(ind);
|
||||
const double y = info.labels[ind];
|
||||
const double abs_y = std::abs(y);
|
||||
|
||||
// only update the denominator after we move forward in time (labels are sorted)
|
||||
// this is Breslow's method for ties
|
||||
accumulated_sum += last_exp_p;
|
||||
if (last_abs_y < abs_y) {
|
||||
exp_p_sum -= accumulated_sum;
|
||||
accumulated_sum = 0;
|
||||
} else {
|
||||
CHECK(last_abs_y <= abs_y) << "CoxRegression: labels must be in sorted order, " <<
|
||||
"MetaInfo::LabelArgsort failed!";
|
||||
}
|
||||
|
||||
if (y > 0) {
|
||||
r_k += 1.0/exp_p_sum;
|
||||
s_k += 1.0/(exp_p_sum*exp_p_sum);
|
||||
}
|
||||
|
||||
const double grad = exp_p*r_k - static_cast<bst_float>(y > 0);
|
||||
const double hess = exp_p*r_k - exp_p*exp_p * s_k;
|
||||
out_gpair->at(ind) = bst_gpair(grad * w, hess * w);
|
||||
|
||||
last_abs_y = abs_y;
|
||||
last_exp_p = exp_p;
|
||||
}
|
||||
}
|
||||
void PredTransform(std::vector<bst_float> *io_preds) override {
|
||||
std::vector<bst_float> &preds = *io_preds;
|
||||
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
|
||||
preds[j] = std::exp(preds[j]);
|
||||
}
|
||||
}
|
||||
void EvalTransform(std::vector<bst_float> *io_preds) override {
|
||||
PredTransform(io_preds);
|
||||
}
|
||||
bst_float ProbToMargin(bst_float base_score) const override {
|
||||
return std::log(base_score);
|
||||
}
|
||||
const char* DefaultEvalMetric(void) const override {
|
||||
return "cox-nloglik";
|
||||
}
|
||||
};
|
||||
|
||||
// register the objective function
|
||||
XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox")
|
||||
.describe("Cox regression for censored survival data (negative labels are considered censored).")
|
||||
.set_body([]() { return new CoxRegression(); });
|
||||
|
||||
// gamma regression
|
||||
class GammaRegression : public ObjFunction {
|
||||
public:
|
||||
|
||||
Reference in New Issue
Block a user