Implement categorical data support for SHAP. (#7053)

* Add CPU implementation.
* Update GPUTreeSHAP.
* Add GPU implementation by defining custom split condition.
This commit is contained in:
Jiaming Yuan 2021-06-25 19:02:46 +08:00 committed by GitHub
parent 663136aa08
commit 8fa32fdda2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 287 additions and 50 deletions

@ -1 +1 @@
Subproject commit 3310a30bb123a49ab12c58e03edc2479512d2f64
Subproject commit 5bba198a7c2b3298dc766740965a4dffa7d8ffa4

View File

@ -567,7 +567,7 @@ class RegTree : public Model {
* \param condition_feature the index of the feature to fix
* \param condition_fraction what fraction of the current weight matches our conditioning feature
*/
void TreeShap(const RegTree::FVec& feat, bst_float* phi, unsigned node_index,
void TreeShap(const RegTree::FVec& feat, bst_float* phi, bst_node_t node_index,
unsigned unique_depth, PathElement* parent_unique_path,
bst_float parent_zero_fraction, bst_float parent_one_fraction,
int parent_feature_index, int condition,

View File

@ -87,9 +87,11 @@ struct BitFieldContainer {
BitFieldContainer() = default;
XGBOOST_DEVICE explicit BitFieldContainer(common::Span<value_type> bits) : bits_{bits} {}
XGBOOST_DEVICE BitFieldContainer(BitFieldContainer const& other) : bits_{other.bits_} {}
BitFieldContainer &operator=(BitFieldContainer const &that) = default;
BitFieldContainer &operator=(BitFieldContainer &&that) = default;
common::Span<value_type> Bits() { return bits_; }
common::Span<value_type const> Bits() const { return bits_; }
XGBOOST_DEVICE common::Span<value_type> Bits() { return bits_; }
XGBOOST_DEVICE common::Span<value_type const> Bits() const { return bits_; }
/*\brief Compute the size of needed memory allocation. The returned value is in terms
* of number of elements with `BitFieldContainer::value_type'.

View File

@ -42,6 +42,12 @@ inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, bst_cat_t
return !s_cats.Check(cat);
}
struct IsCatOp {
XGBOOST_DEVICE bool operator()(FeatureType ft) {
return ft == FeatureType::kCategorical;
}
};
using CatBitField = LBitField32;
using KCatBitField = CLBitField32;
} // namespace common

View File

@ -8,6 +8,7 @@
#include "device_helpers.cuh"
#include "quantile.h"
#include "timer.h"
#include "categorical.h"
namespace xgboost {
namespace common {
@ -17,11 +18,6 @@ using WQSketch = WQuantileSketch<bst_float, bst_float>;
using SketchEntry = WQSketch::Entry;
namespace detail {
struct IsCatOp {
XGBOOST_DEVICE bool operator()(FeatureType ft) {
return ft == FeatureType::kCategorical;
}
};
struct SketchUnique {
XGBOOST_DEVICE bool operator()(SketchEntry const& a, SketchEntry const& b) const {
return a.value - b.value == 0;
@ -122,7 +118,7 @@ class SketchContainer {
has_categorical_ =
!d_feature_types.empty() &&
thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
detail::IsCatOp{});
common::IsCatOp{});
timer_.Init(__func__);
}

View File

@ -415,6 +415,68 @@ class DeviceModel {
}
};
struct ShapSplitCondition {
ShapSplitCondition() = default;
XGBOOST_DEVICE
ShapSplitCondition(float feature_lower_bound, float feature_upper_bound,
bool is_missing_branch, common::CatBitField cats)
: feature_lower_bound(feature_lower_bound),
feature_upper_bound(feature_upper_bound),
is_missing_branch(is_missing_branch), categories{std::move(cats)} {
assert(feature_lower_bound <= feature_upper_bound);
}
/*! Feature values >= lower and < upper flow down this path. */
float feature_lower_bound;
float feature_upper_bound;
/*! Feature value set to true flow down this path. */
common::CatBitField categories;
/*! Do missing values flow down this path? */
bool is_missing_branch;
// Does this instance flow down this path?
XGBOOST_DEVICE bool EvaluateSplit(float x) const {
// is nan
if (isnan(x)) {
return is_missing_branch;
}
if (categories.Size() != 0) {
auto cat = static_cast<uint32_t>(x);
return categories.Check(cat);
} else {
return x >= feature_lower_bound && x < feature_upper_bound;
}
}
// the &= op in bitfiled is per cuda thread, this one loops over the entire
// bitfield.
XGBOOST_DEVICE static common::CatBitField Intersect(common::CatBitField l,
common::CatBitField r) {
if (l.Data() == r.Data()) {
return l;
}
if (l.Size() > r.Size()) {
thrust::swap(l, r);
}
for (size_t i = 0; i < r.Bits().size(); ++i) {
l.Bits()[i] &= r.Bits()[i];
}
return l;
}
// Combine two split conditions on the same feature
XGBOOST_DEVICE void Merge(ShapSplitCondition other) {
// Combine duplicate features
if (categories.Size() != 0 || other.categories.Size() != 0) {
categories = Intersect(categories, other.categories);
} else {
feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);
feature_upper_bound = min(feature_upper_bound, other.feature_upper_bound);
}
is_missing_branch = is_missing_branch && other.is_missing_branch;
}
};
struct PathInfo {
int64_t leaf_position; // -1 not a leaf
size_t length;
@ -422,11 +484,12 @@ struct PathInfo {
};
// Transform model into path element form for GPUTreeShap
void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,
const gbm::GBTreeModel& model, size_t tree_limit,
int gpu_id) {
DeviceModel device_model;
device_model.Init(model, 0, tree_limit, gpu_id);
void ExtractPaths(
dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>> *paths,
DeviceModel *model, dh::device_vector<uint32_t> *path_categories,
int gpu_id) {
auto& device_model = *model;
dh::caching_device_vector<PathInfo> info(device_model.nodes.Size());
dh::XGBCachingDeviceAllocator<PathInfo> alloc;
auto d_nodes = device_model.nodes.ConstDeviceSpan();
@ -462,14 +525,45 @@ void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,
paths->resize(path_segments.back());
auto d_paths = paths->data().get();
auto d_paths = dh::ToSpan(*paths);
auto d_info = info.data().get();
auto d_stats = device_model.stats.ConstDeviceSpan();
auto d_tree_group = device_model.tree_group.ConstDeviceSpan();
auto d_path_segments = path_segments.data().get();
auto d_split_types = device_model.split_types.ConstDeviceSpan();
auto d_cat_segments = device_model.categories_tree_segments.ConstDeviceSpan();
auto d_cat_node_segments = device_model.categories_node_segments.ConstDeviceSpan();
size_t max_cat = 0;
if (thrust::any_of(dh::tbegin(d_split_types), dh::tend(d_split_types),
common::IsCatOp{})) {
dh::PinnedMemory pinned;
auto h_max_cat = pinned.GetSpan<RegTree::Segment>(1);
auto max_elem_it = dh::MakeTransformIterator<size_t>(
dh::tbegin(d_cat_node_segments),
[] __device__(RegTree::Segment seg) { return seg.size; });
size_t max_cat_it =
thrust::max_element(thrust::device, max_elem_it,
max_elem_it + d_cat_node_segments.size()) -
max_elem_it;
dh::safe_cuda(cudaMemcpy(h_max_cat.data(),
d_cat_node_segments.data() + max_cat_it,
h_max_cat.size_bytes(), cudaMemcpyDeviceToHost));
max_cat = h_max_cat[0].size;
CHECK_GE(max_cat, 1);
path_categories->resize(max_cat * paths->size());
}
auto d_model_categories = device_model.categories.DeviceSpan();
common::Span<uint32_t> d_path_categories = dh::ToSpan(*path_categories);
dh::LaunchN(gpu_id, info.size(), [=] __device__(size_t idx) {
auto path_info = d_info[idx];
size_t tree_offset = d_tree_segments[path_info.tree_idx];
TreeView tree{0, path_info.tree_idx, d_nodes,
d_tree_segments, d_split_types, d_cat_segments,
d_cat_node_segments, d_model_categories};
int group = d_tree_group[path_info.tree_idx];
size_t child_idx = path_info.leaf_position;
auto child = d_nodes[child_idx];
@ -481,20 +575,38 @@ void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,
double child_cover = d_stats[child_idx].sum_hess;
double parent_cover = d_stats[parent_idx].sum_hess;
double zero_fraction = child_cover / parent_cover;
auto parent = d_nodes[parent_idx];
auto parent = tree.d_tree[child.Parent()];
bool is_left_path = (tree_offset + parent.LeftChild()) == child_idx;
bool is_missing_path = (!parent.DefaultLeft() && !is_left_path) ||
(parent.DefaultLeft() && is_left_path);
float lower_bound = is_left_path ? -inf : parent.SplitCond();
float upper_bound = is_left_path ? parent.SplitCond() : inf;
d_paths[output_position--] = {
idx, parent.SplitIndex(), group, lower_bound,
upper_bound, is_missing_path, zero_fraction, v};
float lower_bound = -inf;
float upper_bound = inf;
common::CatBitField bits;
if (common::IsCat(tree.cats.split_type, child.Parent())) {
auto path_cats = d_path_categories.subspan(max_cat * output_position, max_cat);
size_t size = tree.cats.node_ptr[child.Parent()].size;
auto node_cats = tree.cats.categories.subspan(tree.cats.node_ptr[child.Parent()].beg, size);
SPAN_CHECK(path_cats.size() >= node_cats.size());
for (size_t i = 0; i < node_cats.size(); ++i) {
path_cats[i] = is_left_path ? ~node_cats[i] : node_cats[i];
}
bits = common::CatBitField{path_cats};
} else {
lower_bound = is_left_path ? -inf : parent.SplitCond();
upper_bound = is_left_path ? parent.SplitCond() : inf;
}
d_paths[output_position--] =
gpu_treeshap::PathElement<ShapSplitCondition>{
idx, parent.SplitIndex(),
group, ShapSplitCondition{lower_bound, upper_bound, is_missing_path, bits},
zero_fraction, v};
child_idx = parent_idx;
child = parent;
}
// Root node has feature -1
d_paths[output_position] = {idx, -1, group, -inf, inf, false, 1.0, v};
d_paths[output_position] = {idx, -1, group, ShapSplitCondition{-inf, inf, false, {}}, 1.0, v};
});
}
@ -696,11 +808,16 @@ class GPUPredictor : public xgboost::Predictor {
void PredictContribution(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned tree_end,
std::vector<bst_float> const*,
std::vector<bst_float> const* tree_weights,
bool approximate, int,
unsigned) const override {
std::string not_implemented{"contribution is not implemented in GPU "
"predictor, use `cpu_predictor` instead."};
if (approximate) {
LOG(FATAL) << "Approximated contribution is not implemented in GPU Predictor.";
LOG(FATAL) << "Approximated " << not_implemented;
}
if (tree_weights != nullptr) {
LOG(FATAL) << "Dart booster feature " << not_implemented;
}
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
out_contribs->SetDevice(generic_param_->gpu_id);
@ -718,16 +835,21 @@ class GPUPredictor : public xgboost::Predictor {
out_contribs->Fill(0.0f);
auto phis = out_contribs->DeviceSpan();
dh::device_vector<gpu_treeshap::PathElement> device_paths;
ExtractPaths(&device_paths, model, tree_end, generic_param_->gpu_id);
dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>>
device_paths;
DeviceModel d_model;
d_model.Init(model, 0, tree_end, generic_param_->gpu_id);
dh::device_vector<uint32_t> categories;
ExtractPaths(&device_paths, &d_model, &categories, generic_param_->gpu_id);
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id);
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature);
gpu_treeshap::GPUTreeShap(
X, device_paths.begin(), device_paths.end(), ngroup,
phis.data() + batch.base_rowid * contributions_columns, phis.size());
auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
gpu_treeshap::GPUTreeShap<dh::XGBDeviceAllocator<int>>(
X, device_paths.begin(), device_paths.end(), ngroup, begin,
dh::tend(phis));
}
// Add the base margin term to last column
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
@ -746,11 +868,15 @@ class GPUPredictor : public xgboost::Predictor {
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned tree_end,
std::vector<bst_float> const*,
std::vector<bst_float> const* tree_weights,
bool approximate) const override {
std::string not_implemented{"contribution is not implemented in GPU "
"predictor, use `cpu_predictor` instead."};
if (approximate) {
LOG(FATAL) << "[Internal error]: " << __func__
<< " approximate is not implemented in GPU Predictor.";
LOG(FATAL) << "Approximated " << not_implemented;
}
if (tree_weights != nullptr) {
LOG(FATAL) << "Dart booster feature " << not_implemented;
}
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
out_contribs->SetDevice(generic_param_->gpu_id);
@ -769,16 +895,21 @@ class GPUPredictor : public xgboost::Predictor {
out_contribs->Fill(0.0f);
auto phis = out_contribs->DeviceSpan();
dh::device_vector<gpu_treeshap::PathElement> device_paths;
ExtractPaths(&device_paths, model, tree_end, generic_param_->gpu_id);
dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>>
device_paths;
DeviceModel d_model;
d_model.Init(model, 0, tree_end, generic_param_->gpu_id);
dh::device_vector<uint32_t> categories;
ExtractPaths(&device_paths, &d_model, &categories, generic_param_->gpu_id);
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id);
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature);
gpu_treeshap::GPUTreeShapInteractions(
X, device_paths.begin(), device_paths.end(), ngroup,
phis.data() + batch.base_rowid * contributions_columns, phis.size());
auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
gpu_treeshap::GPUTreeShapInteractions<dh::XGBDeviceAllocator<int>>(
X, device_paths.begin(), device_paths.end(), ngroup, begin,
dh::tend(phis));
}
// Add the base margin term to last column
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);

View File

@ -1245,7 +1245,7 @@ bst_float UnwoundPathSum(const PathElement *unique_path, unsigned unique_depth,
// recursive computation of SHAP values for a decision tree
void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
unsigned node_index, unsigned unique_depth,
bst_node_t node_index, unsigned unique_depth,
PathElement *parent_unique_path,
bst_float parent_zero_fraction,
bst_float parent_one_fraction, int parent_feature_index,
@ -1278,16 +1278,13 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
// internal node
} else {
// find which branch is "hot" (meaning x would follow it)
unsigned hot_index = 0;
if (feat.IsMissing(split_index)) {
hot_index = node.DefaultChild();
} else if (feat.GetFvalue(split_index) < node.SplitCond()) {
hot_index = node.LeftChild();
} else {
hot_index = node.RightChild();
}
const unsigned cold_index = (static_cast<int>(hot_index) == node.LeftChild() ?
node.RightChild() : node.LeftChild());
auto const &cats = this->GetCategoriesMatrix();
bst_node_t hot_index = predictor::GetNextNode<true, true>(
node, node_index, feat.GetFvalue(split_index),
feat.IsMissing(split_index), cats);
const auto cold_index =
(hot_index == node.LeftChild() ? node.RightChild() : node.LeftChild());
const bst_float w = this->Stat(node_index).sum_hess;
const bst_float hot_zero_fraction = this->Stat(hot_index).sum_hess / w;
const bst_float cold_zero_fraction = this->Stat(cold_index).sum_hess / w;

View File

@ -86,6 +86,11 @@ TEST(CpuPredictor, Basic) {
}
}
TEST(CpuPredictor, IterationRange) {
TestIterationRange("cpu_predictor");
}
TEST(CpuPredictor, ExternalMemory) {
dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm";

View File

@ -224,6 +224,11 @@ TEST(GPUPredictor, Shap) {
}
}
TEST(GPUPredictor, IterationRange) {
TestIterationRange("gpu_predictor");
}
TEST(GPUPredictor, CategoricalPrediction) {
TestCategoricalPrediction("gpu_predictor");
}

View File

@ -281,4 +281,78 @@ void TestCategoricalPredictLeaf(StringView name) {
predictor->PredictLeaf(m.get(), &out_predictions.predictions, model);
ASSERT_EQ(out_predictions.predictions.HostVector()[0], 1);
}
void TestIterationRange(std::string name) {
size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3;
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
std::unique_ptr<Learner> learner{Learner::Create({dmat})};
learner->SetParams(Args{{"num_parallel_tree", std::to_string(kForest)},
{"predictor", name}});
size_t kIters = 10;
for (size_t i = 0; i < kIters; ++i) {
learner->UpdateOneIter(i, dmat);
}
bool bound = false;
std::unique_ptr<Learner> sliced {learner->Slice(0, 3, 1, &bound)};
ASSERT_FALSE(bound);
HostDeviceVector<float> out_predt_sliced;
HostDeviceVector<float> out_predt_ranged;
// margin
{
sliced->Predict(dmat, true, &out_predt_sliced, 0, 0, false, false, false,
false, false);
learner->Predict(dmat, true, &out_predt_ranged, 0, 3, false, false, false,
false, false);
auto const &h_sliced = out_predt_sliced.HostVector();
auto const &h_range = out_predt_ranged.HostVector();
ASSERT_EQ(h_sliced.size(), h_range.size());
ASSERT_EQ(h_sliced, h_range);
}
// SHAP
{
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false,
true, false, false);
learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, true,
false, false);
auto const &h_sliced = out_predt_sliced.HostVector();
auto const &h_range = out_predt_ranged.HostVector();
ASSERT_EQ(h_sliced.size(), h_range.size());
ASSERT_EQ(h_sliced, h_range);
}
// SHAP interaction
{
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false,
false, false, true);
learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, false,
false, true);
auto const &h_sliced = out_predt_sliced.HostVector();
auto const &h_range = out_predt_ranged.HostVector();
ASSERT_EQ(h_sliced.size(), h_range.size());
ASSERT_EQ(h_sliced, h_range);
}
// Leaf
{
sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, true,
false, false, false);
learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, true, false,
false, false);
auto const &h_sliced = out_predt_sliced.HostVector();
auto const &h_range = out_predt_ranged.HostVector();
ASSERT_EQ(h_sliced.size(), h_range.size());
ASSERT_EQ(h_sliced, h_range);
}
}
} // namespace xgboost

View File

@ -68,6 +68,8 @@ void TestPredictionWithLesserFeatures(std::string preditor_name);
void TestCategoricalPrediction(std::string name);
void TestCategoricalPredictLeaf(StringView name);
void TestIterationRange(std::string name);
} // namespace xgboost
#endif // XGBOOST_TEST_PREDICTOR_H_

View File

@ -5,7 +5,7 @@ import numpy as np
import xgboost as xgb
from xgboost.compat import PANDAS_INSTALLED
from hypothesis import given, strategies, assume, settings, note
from hypothesis import given, strategies, assume, settings
if PANDAS_INSTALLED:
from hypothesis.extra.pandas import column, data_frames, range_indexes
@ -275,6 +275,25 @@ class TestGPUPredict:
margin,
1e-3, 1e-3)
def test_shap_categorical(self):
X, y = tm.make_categorical(100, 20, 7, False)
Xy = xgb.DMatrix(X, y, enable_categorical=True)
booster = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=10)
booster.set_param({"predictor": "gpu_predictor"})
shap = booster.predict(Xy, pred_contribs=True)
margin = booster.predict(Xy, output_margin=True)
np.testing.assert_allclose(
np.sum(shap, axis=len(shap.shape) - 1), margin, rtol=1e-3
)
booster.set_param({"predictor": "cpu_predictor"})
shap = booster.predict(Xy, pred_contribs=True)
margin = booster.predict(Xy, output_margin=True)
np.testing.assert_allclose(
np.sum(shap, axis=len(shap.shape) - 1), margin, rtol=1e-3
)
def test_predict_leaf_basic(self):
gpu_leaf = run_predict_leaf('gpu_predictor')
cpu_leaf = run_predict_leaf('cpu_predictor')