SHAP values for feature contributions (#2438)

* SHAP values for feature contributions

* Fix commenting error

* New polynomial time SHAP value estimation algorithm

* Update API to support SHAP values

* Fix merge conflicts with updates in master

* Correct submodule hashes

* Fix variable sized stack allocation

* Make lint happy

* Add docs

* Fix typo

* Adjust tolerances

* Remove unneeded def

* Fixed cpp test setup

* Updated R API and cleaned up

* Fixed test typo

Scott Lundberg, 2017-10-12 12:35:51 -07:00, committed by GitHub
parent ff9180cd73, commit 78c4188cec
16 changed files with 369 additions and 143 deletions


@@ -128,6 +128,7 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
#' It will use all the trees by default (\code{NULL} value).
#' @param predleaf whether predict leaf index instead.
#' @param predcontrib whether to return feature contributions to individual predictions instead (see Details).
#' @param approxcontrib whether to use a fast approximation for feature contributions (see Details).
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#' prediction outputs per case. This option has no effect when \code{predleaf = TRUE}.
#' @param ... Parameters passed to \code{predict.xgb.Booster}
@@ -148,10 +149,11 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
#'
#' Setting \code{predcontrib = TRUE} allows to calculate contributions of each feature to
#' individual predictions. For "gblinear" booster, feature contributions are simply linear terms
#' (feature_beta * feature_value). For "gbtree" booster, feature contribution is calculated
#' as a sum of average contribution of that feature's split nodes across all trees to an
#' individual prediction, following the idea explained in
#' \url{http://blog.datadive.net/interpreting-random-forests/}.
#' (feature_beta * feature_value). For "gbtree" booster, feature contributions are SHAP
#' values (https://arxiv.org/abs/1706.06060) that sum to the difference between the current
#' prediction and the expected output of the model (the hessian weights are used to compute the expectations).
#' Setting \code{approxcontrib = TRUE} approximates these values following the idea explained
#' in \url{http://blog.datadive.net/interpreting-random-forests/}.
#'
#' @return
#' For regression or binary classification, it returns a vector of length \code{nrows(newdata)}.
@@ -195,7 +197,7 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
#' # the result is an nsamples X (nfeatures + 1) matrix
#' pred_contr <- predict(bst, test$data, predcontrib = TRUE)
#' str(pred_contr)
#' # verify that contributions' sums are equal to log-odds of predictions (up to foat precision):
#' # verify that contributions' sums are equal to log-odds of predictions (up to float precision):
#' summary(rowSums(pred_contr) - qlogis(pred))
#' # for the 1st record, let's inspect its features that had non-zero contribution to prediction:
#' contr1 <- pred_contr[1,]
@@ -258,7 +260,7 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
#' @rdname predict.xgb.Booster
#' @export
predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE, ntreelimit = NULL,
predleaf = FALSE, predcontrib = FALSE, reshape = FALSE, ...) {
predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, reshape = FALSE, ...) {
object <- xgb.Booster.complete(object, saveraw = FALSE)
if (!inherits(newdata, "xgb.DMatrix"))
@@ -270,7 +272,7 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
if (ntreelimit < 0)
stop("ntreelimit cannot be negative")
option <- 0L + 1L * as.logical(outputmargin) + 2L * as.logical(predleaf) + 4L * as.logical(predcontrib)
option <- 0L + 1L * as.logical(outputmargin) + 2L * as.logical(predleaf) + 4L * as.logical(predcontrib) + 8L * as.logical(approxcontrib)
ret <- .Call(XGBoosterPredict_R, object$handle, newdata, option[1], as.integer(ntreelimit))


@@ -80,19 +80,26 @@ test_that("predict feature contributions works", {
expect_equal(dim(pred_contr), c(nrow(sparse_matrix), ncol(sparse_matrix) + 1))
expect_equal(colnames(pred_contr), c(colnames(sparse_matrix), "BIAS"))
pred <- predict(bst.Tree, sparse_matrix, outputmargin = TRUE)
expect_lt(max(abs(rowSums(pred_contr) - pred)), 1e-6)
expect_lt(max(abs(rowSums(pred_contr) - pred)), 1e-5)
# gbtree binary classifier (approximate method)
expect_error(pred_contr <- predict(bst.Tree, sparse_matrix, predcontrib = TRUE, approxcontrib = TRUE), regexp = NA)
expect_equal(dim(pred_contr), c(nrow(sparse_matrix), ncol(sparse_matrix) + 1))
expect_equal(colnames(pred_contr), c(colnames(sparse_matrix), "BIAS"))
pred <- predict(bst.Tree, sparse_matrix, outputmargin = TRUE)
expect_lt(max(abs(rowSums(pred_contr) - pred)), 1e-5)
# gblinear binary classifier
expect_error(pred_contr <- predict(bst.GLM, sparse_matrix, predcontrib = TRUE), regexp = NA)
expect_equal(dim(pred_contr), c(nrow(sparse_matrix), ncol(sparse_matrix) + 1))
expect_equal(colnames(pred_contr), c(colnames(sparse_matrix), "BIAS"))
pred <- predict(bst.GLM, sparse_matrix, outputmargin = TRUE)
expect_lt(max(abs(rowSums(pred_contr) - pred)), 2e-6)
expect_lt(max(abs(rowSums(pred_contr) - pred)), 1e-5)
# manual calculation of linear terms
coefs <- xgb.dump(bst.GLM)[-c(1,2,4)] %>% as.numeric
coefs <- c(coefs[-1], coefs[1]) # intercept must be the last
pred_contr_manual <- sweep(cbind(sparse_matrix, 1), 2, coefs, FUN="*")
expect_equal(as.numeric(pred_contr), as.numeric(pred_contr_manual), 2e-6)
expect_equal(as.numeric(pred_contr), as.numeric(pred_contr_manual), 1e-5)
# gbtree multiclass
pred <- predict(mbst.Tree, as.matrix(iris[, -5]), outputmargin = TRUE, reshape = TRUE)
@@ -101,7 +108,7 @@ test_that("predict feature contributions works", {
expect_length(pred_contr, 3)
for (g in seq_along(pred_contr)) {
expect_equal(colnames(pred_contr[[g]]), c(colnames(iris[, -5]), "BIAS"))
expect_lt(max(abs(rowSums(pred_contr[[g]]) - pred[, g])), 2e-6)
expect_lt(max(abs(rowSums(pred_contr[[g]]) - pred[, g])), 1e-5)
}
# gblinear multiclass (set base_score = 0, which is base margin in multiclass)


@@ -115,10 +115,11 @@ class GradientBooster {
* \param out_contribs output vector to hold the contributions
* \param ntree_limit limit the number of trees used in prediction, when it equals 0, this means
* we do not limit number of trees
* \param approximate use a faster (inconsistent) approximation of SHAP values
*/
virtual void PredictContribution(DMatrix* dmat,
std::vector<bst_float>* out_contribs,
unsigned ntree_limit = 0) = 0;
unsigned ntree_limit = 0, bool approximate = false) = 0;
/*!
* \brief dump the model in the requested format


@@ -104,13 +104,15 @@ class Learner : public rabit::Serializable {
* predictor, when it equals 0, this means we are using all the trees
* \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor
* \param pred_contribs whether to only predict the feature contributions
* \param approx_contribs whether to approximate the feature contributions for speed
*/
virtual void Predict(DMatrix* data,
bool output_margin,
std::vector<bst_float> *out_preds,
unsigned ntree_limit = 0,
bool pred_leaf = false,
bool pred_contribs = false) const = 0;
bool pred_contribs = false,
bool approx_contribs = false) const = 0;
/*!
* \brief Set additional attribute to the Booster.
* The property will be saved along the booster.


@@ -144,12 +144,14 @@ class Predictor {
* \param [in,out] out_contribs The output feature contribs.
* \param model Model to make predictions from.
* \param ntree_limit (Optional) The ntree limit.
* \param approximate Use fast approximate algorithm.
*/
virtual void PredictContribution(DMatrix* dmat,
std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned ntree_limit = 0) = 0;
unsigned ntree_limit = 0,
bool approximate = false) = 0;
/**
* \fn static Predictor* Predictor::Create(std::string name);


@@ -14,6 +14,7 @@
#include <string>
#include <cstring>
#include <algorithm>
#include <tuple>
#include "./base.h"
#include "./data.h"
#include "./logging.h"
@@ -411,6 +412,20 @@ struct RTreeNodeStat {
int leaf_child_cnt;
};
// Used by TreeShap
// data we keep about our decision path
// note that pweight is included for convenience and is not tied with the other attributes
// the pweight of the i-th path element is the permutation weight of paths with i-1 ones in them
struct PathElement {
int feature_index;
bst_float zero_fraction;
bst_float one_fraction;
bst_float pweight;
PathElement() {}
PathElement(int i, bst_float z, bst_float o, bst_float w) :
feature_index(i), zero_fraction(z), one_fraction(o), pweight(w) {}
};
/*!
* \brief define regression tree to be the most common tree model.
* This is the data structure used in xgboost's major tree models.
@@ -482,13 +497,26 @@ class RegTree: public TreeModel<bst_float, RTreeNodeStat> {
*/
inline bst_float Predict(const FVec& feat, unsigned root_id = 0) const;
/*!
* \brief calculate the feature contributions for the given root
* \brief calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree
* \param feat dense feature vector, if the feature is missing the field is set to NaN
* \param root_id starting root index of the instance
* \param out_contribs output vector to hold the contributions
*/
inline void CalculateContributions(const RegTree::FVec& feat, unsigned root_id,
bst_float *out_contribs) const;
inline void TreeShap(const RegTree::FVec& feat, bst_float *phi,
unsigned node_index, unsigned unique_depth,
PathElement *parent_unique_path, bst_float parent_zero_fraction,
bst_float parent_one_fraction, int parent_feature_index) const;
/*!
* \brief calculate the approximate feature contributions for the given root
* \param feat dense feature vector, if the feature is missing the field is set to NaN
* \param root_id starting root index of the instance
* \param out_contribs output vector to hold the contributions
*/
inline void CalculateContributionsApprox(const RegTree::FVec& feat, unsigned root_id,
bst_float *out_contribs) const;
/*!
* \brief get next position of the tree given current pid
* \param pid Current node id.
@@ -590,7 +618,7 @@ inline bst_float RegTree::FillNodeMeanValue(int nid) {
return result;
}
inline void RegTree::CalculateContributions(const RegTree::FVec& feat, unsigned root_id,
inline void RegTree::CalculateContributionsApprox(const RegTree::FVec& feat, unsigned root_id,
bst_float *out_contribs) const {
CHECK_GT(this->node_mean_values.size(), 0U);
// this follows the idea of http://blog.datadive.net/interpreting-random-forests/
@@ -617,6 +645,154 @@ inline void RegTree::CalculateContributions(const RegTree::FVec& feat, unsigned
out_contribs[split_index] += leaf_value - node_value;
}
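
For intuition, here is a minimal Python sketch of the approximate attribution idea (from the datadive post) that CalculateContributionsApprox follows: walk the instance's decision path and credit each split feature with the change in the node's cover-weighted mean prediction, with the root mean going to the BIAS slot. The dict-based tree below is illustrative only, not xgboost's RegTree.

def approx_contribs(tree, x, n_features):
    # leaves: {"mean": leaf_value}; internal nodes: {"feature": j, "threshold": t,
    # "left": ..., "right": ..., "mean": cover-weighted mean of the subtree's predictions}
    contribs = [0.0] * (n_features + 1)
    node = tree
    contribs[n_features] = node["mean"]  # BIAS column gets the tree's overall mean
    while "feature" in node:
        child = node["left"] if x[node["feature"]] < node["threshold"] else node["right"]
        contribs[node["feature"]] += child["mean"] - node["mean"]
        node = child
    return contribs  # entries sum to the value of the leaf x falls into

# tiny usage example: a depth-1 stump splitting on feature 0
stump = {"feature": 0, "threshold": 0.5, "mean": 0.25,
         "left": {"mean": 0.0}, "right": {"mean": 1.0}}
print(approx_contribs(stump, [1.0], 1))  # [0.75, 0.25] -> sums to the leaf value 1.0
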
// extend our decision path with a fraction of one and zero extensions
inline void ExtendPath(PathElement *unique_path, unsigned unique_depth,
bst_float zero_fraction, bst_float one_fraction, int feature_index) {
unique_path[unique_depth].feature_index = feature_index;
unique_path[unique_depth].zero_fraction = zero_fraction;
unique_path[unique_depth].one_fraction = one_fraction;
unique_path[unique_depth].pweight = (unique_depth == 0 ? 1 : 0);
for (int i = unique_depth-1; i >= 0; i--) {
unique_path[i+1].pweight += one_fraction*unique_path[i].pweight*(i+1)
/ static_cast<bst_float>(unique_depth+1);
unique_path[i].pweight = zero_fraction*unique_path[i].pweight*(unique_depth-i)
/ static_cast<bst_float>(unique_depth+1);
}
}
// undo a previous extension of the decision path
inline void UnwindPath(PathElement *unique_path, unsigned unique_depth, unsigned path_index) {
const bst_float one_fraction = unique_path[path_index].one_fraction;
const bst_float zero_fraction = unique_path[path_index].zero_fraction;
bst_float next_one_portion = unique_path[unique_depth].pweight;
for (int i = unique_depth-1; i >= 0; --i) {
if (one_fraction != 0) {
const bst_float tmp = unique_path[i].pweight;
unique_path[i].pweight = next_one_portion*(unique_depth+1)
/ static_cast<bst_float>((i+1)*one_fraction);
next_one_portion = tmp - unique_path[i].pweight*zero_fraction*(unique_depth-i)
/ static_cast<bst_float>(unique_depth+1);
} else {
unique_path[i].pweight = (unique_path[i].pweight*(unique_depth+1))
/ static_cast<bst_float>(zero_fraction*(unique_depth-i));
}
}
for (int i = path_index; i < unique_depth; ++i) {
unique_path[i].feature_index = unique_path[i+1].feature_index;
unique_path[i].zero_fraction = unique_path[i+1].zero_fraction;
unique_path[i].one_fraction = unique_path[i+1].one_fraction;
}
}
// determine what the total permutation weight would be if
// we unwound a previous extension in the decision path
inline bst_float UnwoundPathSum(const PathElement *unique_path, unsigned unique_depth,
unsigned path_index) {
const bst_float one_fraction = unique_path[path_index].one_fraction;
const bst_float zero_fraction = unique_path[path_index].zero_fraction;
bst_float next_one_portion = unique_path[unique_depth].pweight;
bst_float total = 0;
for (int i = unique_depth-1; i >= 0; --i) {
if (one_fraction != 0) {
const bst_float tmp = next_one_portion*(unique_depth+1)
/ static_cast<bst_float>((i+1)*one_fraction);
total += tmp;
next_one_portion = unique_path[i].pweight - tmp*zero_fraction*((unique_depth-i)
/ static_cast<bst_float>(unique_depth+1));
} else {
total += (unique_path[i].pweight/zero_fraction)/((unique_depth-i)
/ static_cast<bst_float>(unique_depth+1));
}
}
return total;
}
// recursive computation of SHAP values for a decision tree
inline void RegTree::TreeShap(const RegTree::FVec& feat, bst_float *phi,
unsigned node_index, unsigned unique_depth,
PathElement *parent_unique_path, bst_float parent_zero_fraction,
bst_float parent_one_fraction, int parent_feature_index) const {
const auto node = (*this)[node_index];
// extend the unique path
PathElement *unique_path = parent_unique_path + unique_depth;
if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path+unique_depth, unique_path);
ExtendPath(unique_path, unique_depth, parent_zero_fraction,
parent_one_fraction, parent_feature_index);
const unsigned split_index = node.split_index();
// leaf node
if (node.is_leaf()) {
for (int i = 1; i <= unique_depth; ++i) {
const bst_float w = UnwoundPathSum(unique_path, unique_depth, i);
const PathElement &el = unique_path[i];
phi[el.feature_index] += w*(el.one_fraction-el.zero_fraction)*node.leaf_value();
}
// internal node
} else {
// find which branch is "hot" (meaning x would follow it)
unsigned hot_index = 0;
if (feat.is_missing(split_index)) {
hot_index = node.cdefault();
} else if (feat.fvalue(split_index) < node.split_cond()) {
hot_index = node.cleft();
} else {
hot_index = node.cright();
}
const unsigned cold_index = (hot_index == node.cleft() ? node.cright() : node.cleft());
const bst_float w = this->stat(node_index).sum_hess;
const bst_float hot_zero_fraction = this->stat(hot_index).sum_hess/w;
const bst_float cold_zero_fraction = this->stat(cold_index).sum_hess/w;
bst_float incoming_zero_fraction = 1;
bst_float incoming_one_fraction = 1;
// see if we have already split on this feature,
// if so we undo that split so we can redo it for this node
unsigned path_index = 0;
for (; path_index <= unique_depth; ++path_index) {
if (unique_path[path_index].feature_index == split_index) break;
}
if (path_index != unique_depth+1) {
incoming_zero_fraction = unique_path[path_index].zero_fraction;
incoming_one_fraction = unique_path[path_index].one_fraction;
UnwindPath(unique_path, unique_depth, path_index);
unique_depth -= 1;
}
TreeShap(feat, phi, hot_index, unique_depth+1, unique_path,
hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_index);
TreeShap(feat, phi, cold_index, unique_depth+1, unique_path,
cold_zero_fraction*incoming_zero_fraction, 0, split_index);
}
}
inline void RegTree::CalculateContributions(const RegTree::FVec& feat, unsigned root_id,
bst_float *out_contribs) const {
// find the expected value of the tree's predictions
bst_float base_value = 0.0;
bst_float total_cover = 0;
for (unsigned i = 0; i < (*this).param.num_nodes; ++i) {
const auto node = (*this)[i];
if (node.is_leaf()) {
const auto cover = this->stat(i).sum_hess;
base_value += cover*node.leaf_value();
total_cover += cover;
}
}
out_contribs[feat.size()] += base_value / total_cover;
// Preallocate space for the unique path data
const int maxd = this->MaxDepth(root_id)+1;
std::vector<PathElement> unique_path_data((maxd*(maxd+1))/2);
TreeShap(feat, out_contribs, root_id, 0, unique_path_data.data(), 1, 1, -1);
}
/*! \brief get next position of the tree given current pid */
inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const {
bst_float split_value = (*this)[pid].split_cond();


@@ -990,7 +990,7 @@ class Booster(object):
return self.eval_set([(data, name)], iteration)
def predict(self, data, output_margin=False, ntree_limit=0, pred_leaf=False,
pred_contribs=False):
pred_contribs=False, approx_contribs=False):
"""
Predict with data.
@@ -1018,9 +1018,12 @@
pred_contribs : bool
When this option is on, the output will be a matrix of (nsample, nfeats+1)
with each record indicating the feature contributions of all trees. The sum of
all feature contributions is equal to the prediction. Note that the bias is added
as the final column, on top of the regular features.
with each record indicating the feature contributions (SHAP values) for that
prediction. The sum of all feature contributions is equal to the prediction.
Note that the bias is added as the final column, on top of the regular features.
approx_contribs : bool
Whether to approximate the feature contributions for speed (only used when pred_contribs is True).
Returns
-------
@@ -1033,6 +1036,8 @@
option_mask |= 0x02
if pred_contribs:
option_mask |= 0x04
if approx_contribs:
option_mask |= 0x08
self._validate_features(data)
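
A self-contained usage sketch of the new keyword arguments (the toy data and parameters below are illustrative, not from this commit); both the exact and the approximate contributions sum row-wise to the margin prediction:

import numpy as np
import xgboost as xgb

# small illustrative model; any trained Booster behaves the same way
X = np.random.rand(100, 5)
y = np.random.randint(0, 2, 100)
dtrain = xgb.DMatrix(X, label=y)
bst = xgb.train({"max_depth": 3, "objective": "binary:logistic"}, dtrain, num_boost_round=10)

contribs = bst.predict(dtrain, pred_contribs=True)  # exact SHAP values, shape (nrow, nfeat + 1)
approx = bst.predict(dtrain, pred_contribs=True, approx_contribs=True)  # faster approximation
margin = bst.predict(dtrain, output_margin=True)
assert np.allclose(contribs.sum(axis=1), margin, atol=1e-4)  # additivity, up to float precision
assert np.allclose(approx.sum(axis=1), margin, atol=1e-4)
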


@@ -758,7 +758,8 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
(option_mask & 1) != 0,
&preds, ntree_limit,
(option_mask & 2) != 0,
(option_mask & 4) != 0);
(option_mask & 4) != 0,
(option_mask & 8) != 0);
*out_result = dmlc::BeginPtr(preds);
*len = static_cast<xgboost::bst_ulong>(preds.size());
API_END();
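
For reference, the option_mask bit layout implied by the callers above, summarized as a small snippet (the constant names are ours, not part of the API):

# option_mask bits passed to XGBoosterPredict (names are illustrative)
OUTPUT_MARGIN   = 0x01  # outputmargin / output_margin
PRED_LEAF       = 0x02  # predleaf / pred_leaf
PRED_CONTRIBS   = 0x04  # predcontrib / pred_contribs
APPROX_CONTRIBS = 0x08  # approxcontrib / approx_contribs (added in this commit)
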


@@ -224,7 +224,7 @@ class GBLinear : public GradientBooster {
void PredictContribution(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
unsigned ntree_limit) override {
unsigned ntree_limit, bool approximate) override {
if (model.weight.size() == 0) {
model.InitModel();
}


@@ -233,8 +233,8 @@ class GBTree : public GradientBooster {
void PredictContribution(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
unsigned ntree_limit) override {
predictor->PredictContribution(p_fmat, out_contribs, model_, ntree_limit);
unsigned ntree_limit, bool approximate) override {
predictor->PredictContribution(p_fmat, out_contribs, model_, ntree_limit, approximate);
}
std::vector<std::string> DumpModel(const FeatureMap& fmap,


@@ -433,9 +433,9 @@ class LearnerImpl : public Learner {
void Predict(DMatrix* data, bool output_margin,
std::vector<bst_float>* out_preds, unsigned ntree_limit,
bool pred_leaf, bool pred_contribs) const override {
bool pred_leaf, bool pred_contribs, bool approx_contribs) const override {
if (pred_contribs) {
gbm_->PredictContribution(data, out_preds, ntree_limit);
gbm_->PredictContribution(data, out_preds, ntree_limit, approx_contribs);
} else if (pred_leaf) {
gbm_->PredictLeaf(data, out_preds, ntree_limit);
} else {


@@ -206,9 +206,9 @@ class CPUPredictor : public Predictor {
}
}
void PredictContribution(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned ntree_limit) override {
void PredictContribution(DMatrix* p_fmat, std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned ntree_limit,
bool approximate) override {
const int nthread = omp_get_max_threads();
InitThreadTemp(nthread, model.param.num_feature);
const MetaInfo& info = p_fmat->info();
@@ -225,11 +225,13 @@ class CPUPredictor : public Predictor {
// make sure contributions is zeroed, we could be reusing a previously
// allocated one
std::fill(contribs.begin(), contribs.end(), 0);
if (approximate) {
// initialize tree node mean values
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ntree_limit; ++i) {
model.trees[i]->FillNodeMeanValues();
}
}
// start collecting the contributions
dmlc::DataIter<RowBatch>* iter = p_fmat->RowIterator();
const std::vector<bst_float>& base_margin = info.base_margin;
@ -253,7 +255,11 @@ class CPUPredictor : public Predictor {
if (model.tree_info[j] != gid) {
continue;
}
if (!approximate) {
model.trees[j]->CalculateContributions(feats, root_id, p_contribs);
} else {
model.trees[j]->CalculateContributionsApprox(feats, root_id, p_contribs);
}
}
feats.Drop(batch[i]);
// add base margin to BIAS


@@ -384,9 +384,10 @@ class GPUPredictor : public xgboost::Predictor {
void PredictContribution(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned ntree_limit) override {
unsigned ntree_limit,
bool approximate) override {
cpu_predictor->PredictContribution(p_fmat, out_contribs, model,
ntree_limit);
ntree_limit, approximate);
}
void Init(const std::vector<std::pair<std::string, std::string>>& cfg,


@@ -12,6 +12,7 @@ TEST(cpu_predictor, Test) {
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
trees.back()->InitModel();
(*trees.back())[0].set_leaf(1.5f);
(*trees.back()).stat(0).sum_hess = 1.0f;
gbm::GBTreeModel model(0.5);
model.CommitModel(std::move(trees), 0);
model.param.num_output_group = 1;
@@ -50,5 +51,11 @@ TEST(cpu_predictor, Test) {
for (int i = 0; i < out_contribution.size(); i++) {
ASSERT_EQ(out_contribution[i], 1.5);
}
// Test predict contribution (approximate method)
cpu_predictor->PredictContribution(dmat.get(), &out_contribution, model, 0, true);
for (int i = 0; i < out_contribution.size(); i++) {
ASSERT_EQ(out_contribution[i], 1.5);
}
}
} // namespace xgboost


@@ -19,6 +19,7 @@ TEST(gpu_predictor, Test) {
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
trees.back()->InitModel();
(*trees.back())[0].set_leaf(1.5f);
(*trees.back()).stat(0).sum_hess = 1.0f;
gbm::GBTreeModel model(0.5);
model.CommitModel(std::move(trees), 0);
model.param.num_output_group = 1;


@@ -291,3 +291,18 @@ def test_contributions():
for max_depth, num_rounds in itertools.product(range(0, 3), range(1, 5)):
yield test_fn, max_depth, num_rounds
# check that we get the right SHAP values for a basic AND example
# (https://arxiv.org/abs/1706.06060)
X = np.zeros((4, 2))
X[0, :] = 1
X[1, 0] = 1
X[2, 1] = 1
y = np.zeros(4)
y[0] = 1
param = {"max_depth": 2, "base_score": 0.0, "eta": 1.0, "lambda": 0}
bst = xgb.train(param, xgb.DMatrix(X, label=y), 1)
out = bst.predict(xgb.DMatrix(X[0:1, :]), pred_contribs=True)
assert out[0, 0] == 0.375
assert out[0, 1] == 0.375
assert out[0, 2] == 0.25
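
As a sanity check on these constants, a brute-force evaluation of the Shapley definition reproduces them. The sketch below assumes the single learned tree reproduces the AND function exactly, and it fills in a missing feature by averaging over the four training rows (which matches the tree's cover-weighted expectation here, since every row carries unit hessian weight):

from itertools import combinations
from math import factorial

import numpy as np

def brute_force_shap(f, background, x):
    # exact Shapley values of f at x; features outside the subset S are filled in
    # from each background row in turn and the results are averaged
    n = len(x)
    def value(subset):
        data = background.copy()
        for j in subset:
            data[:, j] = x[j]
        return np.mean([f(row) for row in data])
    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            for subset in combinations(others, size):
                weight = factorial(size) * factorial(n - size - 1) / factorial(n)
                phi[i] += weight * (value(subset + (i,)) - value(subset))
    return phi, value(())

X = np.zeros((4, 2)); X[0, :] = 1; X[1, 0] = 1; X[2, 1] = 1  # same rows as the test above
and_fn = lambda row: float(row[0] > 0.5 and row[1] > 0.5)
phi, bias = brute_force_shap(and_fn, X, X[0])
print(phi, bias)  # ~[0.375, 0.375] and 0.25, matching out[0, :]
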