Implement categorical data support for SHAP. (#7053)

* Add CPU implementation.
* Update GPUTreeSHAP.
* Add GPU implementation by defining custom split condition.
This commit is contained in:
commit 8fa32fdda2 (parent 663136aa08)
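The heart of the GPU implementation is the "custom split condition" named in the commit message. GPUTreeSHAP previously hard-coded each path element's condition as a numerical interval (lower bound, upper bound, missing-branch flag), which cannot represent a categorical split; the updated library takes the condition as a template parameter, and XGBoost supplies ShapSplitCondition in the gpu_predictor.cu hunks below. The following host-only sketch distills the concept; it is illustrative only, with std::bitset standing in for common::CatBitField and a fixed 32-category limit that the real code does not have.

#include <algorithm>
#include <bitset>
#include <cmath>

struct ToySplitCondition {
  // Numerical splits: feature values in [lower, upper) follow this path.
  float lower{-INFINITY};
  float upper{INFINITY};
  // Categorical splits: one bit per matching category (empty means numerical).
  std::bitset<32> categories{};
  bool is_missing_branch{false};

  // Does feature value x flow down this path?
  bool EvaluateSplit(float x) const {
    if (std::isnan(x)) { return is_missing_branch; }
    if (categories.any()) {  // categorical: set membership, not ordering
      return categories.test(static_cast<unsigned>(x));
    }
    return x >= lower && x < upper;
  }

  // Merge duplicate occurrences of one feature along a path: intersect
  // intervals for numerical splits, bitsets for categorical ones.
  void Merge(ToySplitCondition const& other) {
    if (categories.any() || other.categories.any()) {
      categories &= other.categories;
    } else {
      lower = std::max(lower, other.lower);
      upper = std::min(upper, other.upper);
    }
    is_missing_branch = is_missing_branch && other.is_missing_branch;
  }
};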
gputreeshap (submodule)
@@ -1 +1 @@
-Subproject commit 3310a30bb123a49ab12c58e03edc2479512d2f64
+Subproject commit 5bba198a7c2b3298dc766740965a4dffa7d8ffa4

include/xgboost/tree_model.h
@@ -567,7 +567,7 @@ class RegTree : public Model {
    * \param condition_feature the index of the feature to fix
    * \param condition_fraction what fraction of the current weight matches our conditioning feature
    */
-  void TreeShap(const RegTree::FVec& feat, bst_float* phi, unsigned node_index,
+  void TreeShap(const RegTree::FVec& feat, bst_float* phi, bst_node_t node_index,
                 unsigned unique_depth, PathElement* parent_unique_path,
                 bst_float parent_zero_fraction, bst_float parent_one_fraction,
                 int parent_feature_index, int condition,

src/common/bitfield.h
@@ -87,9 +87,11 @@ struct BitFieldContainer {
   BitFieldContainer() = default;
   XGBOOST_DEVICE explicit BitFieldContainer(common::Span<value_type> bits) : bits_{bits} {}
   XGBOOST_DEVICE BitFieldContainer(BitFieldContainer const& other) : bits_{other.bits_} {}
+  BitFieldContainer &operator=(BitFieldContainer const &that) = default;
+  BitFieldContainer &operator=(BitFieldContainer &&that) = default;

-  common::Span<value_type> Bits() { return bits_; }
-  common::Span<value_type const> Bits() const { return bits_; }
+  XGBOOST_DEVICE common::Span<value_type> Bits() { return bits_; }
+  XGBOOST_DEVICE common::Span<value_type const> Bits() const { return bits_; }

   /*\brief Compute the size of needed memory allocation. The returned value is in terms
    * of number of elements with `BitFieldContainer::value_type'.

src/common/categorical.h
@@ -42,6 +42,12 @@ inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, bst_cat_t
   return !s_cats.Check(cat);
 }

+struct IsCatOp {
+  XGBOOST_DEVICE bool operator()(FeatureType ft) {
+    return ft == FeatureType::kCategorical;
+  }
+};
+
 using CatBitField = LBitField32;
 using KCatBitField = CLBitField32;
 }  // namespace common

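Two supporting changes enable category sets to live inside device code: BitFieldContainer (first hunk above) gains defaulted copy and move assignment plus device-visible Bits() accessors, which ShapSplitCondition::Intersect later relies on, and IsCatOp (second hunk) moves into the common namespace beside CatBitField so the predictor can reuse it. As orientation, here is a sketch of the word/mask arithmetic an LBitField32-style container performs; the real BitFieldContainer shown above additionally abstracts the word type and bit ordering, so treat this as an assumption-laden toy, not its implementation.

#include <cstdint>
#include <vector>

struct ToyBitField32 {
  // One bit per category, packed 32 categories to a uint32_t word.
  std::vector<uint32_t> bits;
  explicit ToyBitField32(size_t n_cats) : bits((n_cats + 31) / 32, 0u) {}

  void Set(uint32_t cat) { bits[cat / 32] |= (1u << (cat % 32)); }
  bool Check(uint32_t cat) const { return (bits[cat / 32] >> (cat % 32)) & 1u; }
};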
src/common/quantile.cuh
@@ -8,6 +8,7 @@
 #include "device_helpers.cuh"
 #include "quantile.h"
 #include "timer.h"
+#include "categorical.h"

 namespace xgboost {
 namespace common {
@@ -17,11 +18,6 @@ using WQSketch = WQuantileSketch<bst_float, bst_float>;
 using SketchEntry = WQSketch::Entry;

 namespace detail {
-struct IsCatOp {
-  XGBOOST_DEVICE bool operator()(FeatureType ft) {
-    return ft == FeatureType::kCategorical;
-  }
-};
 struct SketchUnique {
   XGBOOST_DEVICE bool operator()(SketchEntry const& a, SketchEntry const& b) const {
     return a.value - b.value == 0;
@@ -122,7 +118,7 @@ class SketchContainer {
     has_categorical_ =
         !d_feature_types.empty() &&
         thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
-                       detail::IsCatOp{});
+                       common::IsCatOp{});

     timer_.Init(__func__);
   }

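The quantile sketch behaves exactly as before; it now reuses common::IsCatOp rather than keeping a private detail:: copy. The check it performs is equivalent to this host-side sketch, with std::any_of in place of thrust::any_of and FeatureType mirrored from the categorical.h hunk; the helper name HasCategorical is this sketch's, not XGBoost's.

#include <algorithm>
#include <cstdint>
#include <vector>

enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 };

struct IsCatOp {
  bool operator()(FeatureType ft) const { return ft == FeatureType::kCategorical; }
};

// True when at least one feature is categorical, false for an empty list.
bool HasCategorical(std::vector<FeatureType> const& fts) {
  return !fts.empty() && std::any_of(fts.cbegin(), fts.cend(), IsCatOp{});
}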
src/predictor/gpu_predictor.cu
@@ -415,6 +415,68 @@ class DeviceModel {
   }
 };

+struct ShapSplitCondition {
+  ShapSplitCondition() = default;
+  XGBOOST_DEVICE
+  ShapSplitCondition(float feature_lower_bound, float feature_upper_bound,
+                     bool is_missing_branch, common::CatBitField cats)
+      : feature_lower_bound(feature_lower_bound),
+        feature_upper_bound(feature_upper_bound),
+        is_missing_branch(is_missing_branch), categories{std::move(cats)} {
+    assert(feature_lower_bound <= feature_upper_bound);
+  }
+
+  /*! Feature values >= lower and < upper flow down this path. */
+  float feature_lower_bound;
+  float feature_upper_bound;
+  /*! Feature values whose category bit is set flow down this path. */
+  common::CatBitField categories;
+  /*! Do missing values flow down this path? */
+  bool is_missing_branch;
+
+  // Does this instance flow down this path?
+  XGBOOST_DEVICE bool EvaluateSplit(float x) const {
+    // is nan
+    if (isnan(x)) {
+      return is_missing_branch;
+    }
+    if (categories.Size() != 0) {
+      auto cat = static_cast<uint32_t>(x);
+      return categories.Check(cat);
+    } else {
+      return x >= feature_lower_bound && x < feature_upper_bound;
+    }
+  }
+
+  // The &= op in the bitfield is per CUDA thread; this one loops over the
+  // entire bitfield.
+  XGBOOST_DEVICE static common::CatBitField Intersect(common::CatBitField l,
+                                                      common::CatBitField r) {
+    if (l.Data() == r.Data()) {
+      return l;
+    }
+    if (l.Size() > r.Size()) {
+      thrust::swap(l, r);
+    }
+    for (size_t i = 0; i < r.Bits().size(); ++i) {
+      l.Bits()[i] &= r.Bits()[i];
+    }
+    return l;
+  }
+
+  // Combine two split conditions on the same feature
+  XGBOOST_DEVICE void Merge(ShapSplitCondition other) {
+    // Combine duplicate features
+    if (categories.Size() != 0 || other.categories.Size() != 0) {
+      categories = Intersect(categories, other.categories);
+    } else {
+      feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);
+      feature_upper_bound = min(feature_upper_bound, other.feature_upper_bound);
+    }
+    is_missing_branch = is_missing_branch && other.is_missing_branch;
+  }
+};
+
 struct PathInfo {
   int64_t leaf_position;  // -1 not a leaf
   size_t length;
@@ -422,11 +484,12 @@ struct PathInfo {
 };

 // Transform model into path element form for GPUTreeShap
-void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,
-                  const gbm::GBTreeModel& model, size_t tree_limit,
+void ExtractPaths(
+    dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>> *paths,
+    DeviceModel *model, dh::device_vector<uint32_t> *path_categories,
     int gpu_id) {
-  DeviceModel device_model;
-  device_model.Init(model, 0, tree_limit, gpu_id);
+  auto& device_model = *model;
   dh::caching_device_vector<PathInfo> info(device_model.nodes.Size());
   dh::XGBCachingDeviceAllocator<PathInfo> alloc;
   auto d_nodes = device_model.nodes.ConstDeviceSpan();
@@ -462,14 +525,45 @@ void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,

   paths->resize(path_segments.back());

-  auto d_paths = paths->data().get();
+  auto d_paths = dh::ToSpan(*paths);
   auto d_info = info.data().get();
   auto d_stats = device_model.stats.ConstDeviceSpan();
   auto d_tree_group = device_model.tree_group.ConstDeviceSpan();
   auto d_path_segments = path_segments.data().get();

+  auto d_split_types = device_model.split_types.ConstDeviceSpan();
+  auto d_cat_segments = device_model.categories_tree_segments.ConstDeviceSpan();
+  auto d_cat_node_segments = device_model.categories_node_segments.ConstDeviceSpan();
+
+  size_t max_cat = 0;
+  if (thrust::any_of(dh::tbegin(d_split_types), dh::tend(d_split_types),
+                     common::IsCatOp{})) {
+    dh::PinnedMemory pinned;
+    auto h_max_cat = pinned.GetSpan<RegTree::Segment>(1);
+    auto max_elem_it = dh::MakeTransformIterator<size_t>(
+        dh::tbegin(d_cat_node_segments),
+        [] __device__(RegTree::Segment seg) { return seg.size; });
+    size_t max_cat_it =
+        thrust::max_element(thrust::device, max_elem_it,
+                            max_elem_it + d_cat_node_segments.size()) -
+        max_elem_it;
+    dh::safe_cuda(cudaMemcpy(h_max_cat.data(),
+                             d_cat_node_segments.data() + max_cat_it,
+                             h_max_cat.size_bytes(), cudaMemcpyDeviceToHost));
+    max_cat = h_max_cat[0].size;
+    CHECK_GE(max_cat, 1);
+    path_categories->resize(max_cat * paths->size());
+  }
+
+  auto d_model_categories = device_model.categories.DeviceSpan();
+  common::Span<uint32_t> d_path_categories = dh::ToSpan(*path_categories);
+
   dh::LaunchN(gpu_id, info.size(), [=] __device__(size_t idx) {
     auto path_info = d_info[idx];
     size_t tree_offset = d_tree_segments[path_info.tree_idx];
+    TreeView tree{0, path_info.tree_idx, d_nodes,
+                  d_tree_segments, d_split_types, d_cat_segments,
+                  d_cat_node_segments, d_model_categories};
     int group = d_tree_group[path_info.tree_idx];
     size_t child_idx = path_info.leaf_position;
     auto child = d_nodes[child_idx];
@@ -481,20 +575,38 @@ void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,
       double child_cover = d_stats[child_idx].sum_hess;
       double parent_cover = d_stats[parent_idx].sum_hess;
       double zero_fraction = child_cover / parent_cover;
-      auto parent = d_nodes[parent_idx];
+      auto parent = tree.d_tree[child.Parent()];

       bool is_left_path = (tree_offset + parent.LeftChild()) == child_idx;
       bool is_missing_path = (!parent.DefaultLeft() && !is_left_path) ||
                              (parent.DefaultLeft() && is_left_path);
-      float lower_bound = is_left_path ? -inf : parent.SplitCond();
-      float upper_bound = is_left_path ? parent.SplitCond() : inf;
-      d_paths[output_position--] = {
-          idx, parent.SplitIndex(), group, lower_bound,
-          upper_bound, is_missing_path, zero_fraction, v};
+
+      float lower_bound = -inf;
+      float upper_bound = inf;
+      common::CatBitField bits;
+      if (common::IsCat(tree.cats.split_type, child.Parent())) {
+        auto path_cats = d_path_categories.subspan(max_cat * output_position, max_cat);
+        size_t size = tree.cats.node_ptr[child.Parent()].size;
+        auto node_cats = tree.cats.categories.subspan(tree.cats.node_ptr[child.Parent()].beg, size);
+        SPAN_CHECK(path_cats.size() >= node_cats.size());
+        for (size_t i = 0; i < node_cats.size(); ++i) {
+          path_cats[i] = is_left_path ? ~node_cats[i] : node_cats[i];
+        }
+        bits = common::CatBitField{path_cats};
+      } else {
+        lower_bound = is_left_path ? -inf : parent.SplitCond();
+        upper_bound = is_left_path ? parent.SplitCond() : inf;
+      }
+      d_paths[output_position--] =
+          gpu_treeshap::PathElement<ShapSplitCondition>{
+              idx, parent.SplitIndex(),
+              group, ShapSplitCondition{lower_bound, upper_bound, is_missing_path, bits},
+              zero_fraction, v};
       child_idx = parent_idx;
       child = parent;
     }
     // Root node has feature -1
-    d_paths[output_position] = {idx, -1, group, -inf, inf, false, 1.0, v};
+    d_paths[output_position] = {idx, -1, group, ShapSplitCondition{-inf, inf, false, {}}, 1.0, v};
   });
 }

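A detail worth pausing on in ExtractPaths: for a categorical split, the path element stores a slice of the category bitfield rather than bounds, and the left-branch path stores the complement of the node's category bits. That is consistent with the Decision hunk earlier (return !s_cats.Check(cat);): a set bit fails the left-branch test, so matching categories flow right. A toy host-only walk-through, with std::bitset standing in for the device bitfield:

#include <bitset>
#include <cassert>

int main() {
  // Suppose a node splits on categories {1, 3}: bits 1 and 3 are set.
  std::bitset<8> node_cats;
  node_cats.set(1);
  node_cats.set(3);

  std::bitset<8> right_path = node_cats;   // right branch keeps the raw bits
  std::bitset<8> left_path = ~node_cats;   // left branch stores the complement

  assert(right_path.test(3) && !left_path.test(3));  // category 3 goes right
  assert(left_path.test(2) && !right_path.test(2));  // category 2 goes left
  return 0;
}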
src/predictor/gpu_predictor.cu
@@ -696,11 +808,16 @@ class GPUPredictor : public xgboost::Predictor {
   void PredictContribution(DMatrix* p_fmat,
                            HostDeviceVector<bst_float>* out_contribs,
                            const gbm::GBTreeModel& model, unsigned tree_end,
-                           std::vector<bst_float> const*,
+                           std::vector<bst_float> const* tree_weights,
                            bool approximate, int,
                            unsigned) const override {
+    std::string not_implemented{"contribution is not implemented in GPU "
+                                "predictor, use `cpu_predictor` instead."};
     if (approximate) {
-      LOG(FATAL) << "Approximated contribution is not implemented in GPU Predictor.";
+      LOG(FATAL) << "Approximated " << not_implemented;
+    }
+    if (tree_weights != nullptr) {
+      LOG(FATAL) << "Dart booster feature " << not_implemented;
     }
     dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
     out_contribs->SetDevice(generic_param_->gpu_id);
@@ -718,16 +835,21 @@ class GPUPredictor : public xgboost::Predictor {
     out_contribs->Fill(0.0f);
     auto phis = out_contribs->DeviceSpan();

-    dh::device_vector<gpu_treeshap::PathElement> device_paths;
-    ExtractPaths(&device_paths, model, tree_end, generic_param_->gpu_id);
+    dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>>
+        device_paths;
+    DeviceModel d_model;
+    d_model.Init(model, 0, tree_end, generic_param_->gpu_id);
+    dh::device_vector<uint32_t> categories;
+    ExtractPaths(&device_paths, &d_model, &categories, generic_param_->gpu_id);
     for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
       batch.data.SetDevice(generic_param_->gpu_id);
       batch.offset.SetDevice(generic_param_->gpu_id);
       SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
                        model.learner_model_param->num_feature);
-      gpu_treeshap::GPUTreeShap(
-          X, device_paths.begin(), device_paths.end(), ngroup,
-          phis.data() + batch.base_rowid * contributions_columns, phis.size());
+      auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
+      gpu_treeshap::GPUTreeShap<dh::XGBDeviceAllocator<int>>(
+          X, device_paths.begin(), device_paths.end(), ngroup, begin,
+          dh::tend(phis));
     }
     // Add the base margin term to last column
     p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
@@ -746,11 +868,15 @@ class GPUPredictor : public xgboost::Predictor {
                                        HostDeviceVector<bst_float>* out_contribs,
                                        const gbm::GBTreeModel& model,
                                        unsigned tree_end,
-                                       std::vector<bst_float> const*,
+                                       std::vector<bst_float> const* tree_weights,
                                        bool approximate) const override {
+    std::string not_implemented{"contribution is not implemented in GPU "
+                                "predictor, use `cpu_predictor` instead."};
     if (approximate) {
-      LOG(FATAL) << "[Internal error]: " << __func__
-                 << " approximate is not implemented in GPU Predictor.";
+      LOG(FATAL) << "Approximated " << not_implemented;
+    }
+    if (tree_weights != nullptr) {
+      LOG(FATAL) << "Dart booster feature " << not_implemented;
     }
     dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
     out_contribs->SetDevice(generic_param_->gpu_id);
@@ -769,16 +895,21 @@ class GPUPredictor : public xgboost::Predictor {
     out_contribs->Fill(0.0f);
     auto phis = out_contribs->DeviceSpan();

-    dh::device_vector<gpu_treeshap::PathElement> device_paths;
-    ExtractPaths(&device_paths, model, tree_end, generic_param_->gpu_id);
+    dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>>
+        device_paths;
+    DeviceModel d_model;
+    d_model.Init(model, 0, tree_end, generic_param_->gpu_id);
+    dh::device_vector<uint32_t> categories;
+    ExtractPaths(&device_paths, &d_model, &categories, generic_param_->gpu_id);
     for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
       batch.data.SetDevice(generic_param_->gpu_id);
       batch.offset.SetDevice(generic_param_->gpu_id);
       SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
                        model.learner_model_param->num_feature);
-      gpu_treeshap::GPUTreeShapInteractions(
-          X, device_paths.begin(), device_paths.end(), ngroup,
-          phis.data() + batch.base_rowid * contributions_columns, phis.size());
+      auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
+      gpu_treeshap::GPUTreeShapInteractions<dh::XGBDeviceAllocator<int>>(
+          X, device_paths.begin(), device_paths.end(), ngroup, begin,
+          dh::tend(phis));
     }
     // Add the base margin term to last column
     p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);

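Both SHAP entry points now reject dart tree weights explicitly and pass GPUTreeShap an iterator range rather than a raw pointer plus length. Whichever predictor runs, the result must satisfy SHAP's additivity property, which the Python test at the end of this commit checks: per row (and per class), the feature contributions plus the bias column sum to the raw margin. A hypothetical host-side checker for that invariant; the helper name and tolerance handling are this sketch's own.

#include <cmath>
#include <numeric>
#include <vector>

// phi holds n_features contributions followed by one bias term for a single
// row (and, for multi-class models, a single class).
bool ShapSumsToMargin(std::vector<float> const& phi, float margin,
                      float rtol = 1e-3f) {
  float sum = std::accumulate(phi.cbegin(), phi.cend(), 0.0f);
  return std::abs(sum - margin) <= rtol * std::abs(margin);
}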
src/tree/tree_model.cc
@@ -1245,7 +1245,7 @@ bst_float UnwoundPathSum(const PathElement *unique_path, unsigned unique_depth,

 // recursive computation of SHAP values for a decision tree
 void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
-                       unsigned node_index, unsigned unique_depth,
+                       bst_node_t node_index, unsigned unique_depth,
                        PathElement *parent_unique_path,
                        bst_float parent_zero_fraction,
                        bst_float parent_one_fraction, int parent_feature_index,
@@ -1278,16 +1278,13 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
   // internal node
   } else {
     // find which branch is "hot" (meaning x would follow it)
-    unsigned hot_index = 0;
-    if (feat.IsMissing(split_index)) {
-      hot_index = node.DefaultChild();
-    } else if (feat.GetFvalue(split_index) < node.SplitCond()) {
-      hot_index = node.LeftChild();
-    } else {
-      hot_index = node.RightChild();
-    }
-    const unsigned cold_index = (static_cast<int>(hot_index) == node.LeftChild() ?
-                                 node.RightChild() : node.LeftChild());
+    auto const &cats = this->GetCategoriesMatrix();
+    bst_node_t hot_index = predictor::GetNextNode<true, true>(
+        node, node_index, feat.GetFvalue(split_index),
+        feat.IsMissing(split_index), cats);
+
+    const auto cold_index =
+        (hot_index == node.LeftChild() ? node.RightChild() : node.LeftChild());
     const bst_float w = this->Stat(node_index).sum_hess;
     const bst_float hot_zero_fraction = this->Stat(hot_index).sum_hess / w;
     const bst_float cold_zero_fraction = this->Stat(cold_index).sum_hess / w;

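On the CPU side, TreeShap stops hand-rolling the branch test and asks predictor::GetNextNode<true, true> for the hot child, so numerical splits, categorical splits, and missing values are all resolved by the same routine the predictor uses. A simplified sketch of the decision that call encapsulates; ToyNode and the in_category flag are stand-ins for the real node type and the category-bitfield lookup, and the direction convention (set bit goes right) follows the Decision hunk earlier.

struct ToyNode {
  int left, right, default_child;  // child node indices
  float split_cond;                // threshold for numerical splits
  bool categorical;                // true when this node splits on categories
};

int NextNode(ToyNode const& n, float fvalue, bool is_missing, bool in_category) {
  if (is_missing) { return n.default_child; }  // follow the default branch
  if (n.categorical) { return in_category ? n.right : n.left; }
  return fvalue < n.split_cond ? n.left : n.right;  // numerical comparison
}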
tests/cpp/predictor/test_cpu_predictor.cc
@@ -86,6 +86,11 @@ TEST(CpuPredictor, Basic) {
   }
 }

+
+TEST(CpuPredictor, IterationRange) {
+  TestIterationRange("cpu_predictor");
+}
+
 TEST(CpuPredictor, ExternalMemory) {
   dmlc::TemporaryDirectory tmpdir;
   std::string filename = tmpdir.path + "/big.libsvm";

tests/cpp/predictor/test_gpu_predictor.cu
@@ -224,6 +224,11 @@ TEST(GPUPredictor, Shap) {
   }
 }

+TEST(GPUPredictor, IterationRange) {
+  TestIterationRange("gpu_predictor");
+}
+
+
 TEST(GPUPredictor, CategoricalPrediction) {
   TestCategoricalPrediction("gpu_predictor");
 }

tests/cpp/predictor/test_predictor.cc
@@ -281,4 +281,78 @@ void TestCategoricalPredictLeaf(StringView name) {
   predictor->PredictLeaf(m.get(), &out_predictions.predictions, model);
   ASSERT_EQ(out_predictions.predictions.HostVector()[0], 1);
 }
+
+
+void TestIterationRange(std::string name) {
+  size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3;
+  auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
+  std::unique_ptr<Learner> learner{Learner::Create({dmat})};
+
+  learner->SetParams(Args{{"num_parallel_tree", std::to_string(kForest)},
+                          {"predictor", name}});
+
+  size_t kIters = 10;
+  for (size_t i = 0; i < kIters; ++i) {
+    learner->UpdateOneIter(i, dmat);
+  }
+
+  bool bound = false;
+  std::unique_ptr<Learner> sliced {learner->Slice(0, 3, 1, &bound)};
+  ASSERT_FALSE(bound);
+
+  HostDeviceVector<float> out_predt_sliced;
+  HostDeviceVector<float> out_predt_ranged;
+
+  // margin
+  {
+    sliced->Predict(dmat, true, &out_predt_sliced, 0, 0, false, false, false,
+                    false, false);
+
+    learner->Predict(dmat, true, &out_predt_ranged, 0, 3, false, false, false,
+                     false, false);
+
+    auto const &h_sliced = out_predt_sliced.HostVector();
+    auto const &h_range = out_predt_ranged.HostVector();
+    ASSERT_EQ(h_sliced.size(), h_range.size());
+    ASSERT_EQ(h_sliced, h_range);
+  }
+
+  // SHAP
+  {
+    sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false,
+                    true, false, false);

+    learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, true,
+                     false, false);
+
+    auto const &h_sliced = out_predt_sliced.HostVector();
+    auto const &h_range = out_predt_ranged.HostVector();
+    ASSERT_EQ(h_sliced.size(), h_range.size());
+    ASSERT_EQ(h_sliced, h_range);
+  }
+
+  // SHAP interaction
+  {
+    sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, false,
+                    false, false, true);
+    learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, false, false,
+                     false, true);
+    auto const &h_sliced = out_predt_sliced.HostVector();
+    auto const &h_range = out_predt_ranged.HostVector();
+    ASSERT_EQ(h_sliced.size(), h_range.size());
+    ASSERT_EQ(h_sliced, h_range);
+  }
+
+  // Leaf
+  {
+    sliced->Predict(dmat, false, &out_predt_sliced, 0, 0, false, true,
+                    false, false, false);
+    learner->Predict(dmat, false, &out_predt_ranged, 0, 3, false, true, false,
+                     false, false);
+    auto const &h_sliced = out_predt_sliced.HostVector();
+    auto const &h_range = out_predt_ranged.HostVector();
+    ASSERT_EQ(h_sliced.size(), h_range.size());
+    ASSERT_EQ(h_sliced, h_range);
+  }
+}
 }  // namespace xgboost

tests/cpp/predictor/test_predictor.h
@@ -68,6 +68,8 @@ void TestPredictionWithLesserFeatures(std::string preditor_name);
 void TestCategoricalPrediction(std::string name);

 void TestCategoricalPredictLeaf(StringView name);
+
+void TestIterationRange(std::string name);
 }  // namespace xgboost

 #endif  // XGBOOST_TEST_PREDICTOR_H_

tests/python-gpu/test_gpu_prediction.py
@@ -5,7 +5,7 @@ import numpy as np
 import xgboost as xgb
 from xgboost.compat import PANDAS_INSTALLED

-from hypothesis import given, strategies, assume, settings, note
+from hypothesis import given, strategies, assume, settings

 if PANDAS_INSTALLED:
     from hypothesis.extra.pandas import column, data_frames, range_indexes
@@ -275,6 +275,25 @@ class TestGPUPredict:
             margin,
             1e-3, 1e-3)

+    def test_shap_categorical(self):
+        X, y = tm.make_categorical(100, 20, 7, False)
+        Xy = xgb.DMatrix(X, y, enable_categorical=True)
+        booster = xgb.train({"tree_method": "gpu_hist"}, Xy, num_boost_round=10)
+
+        booster.set_param({"predictor": "gpu_predictor"})
+        shap = booster.predict(Xy, pred_contribs=True)
+        margin = booster.predict(Xy, output_margin=True)
+        np.testing.assert_allclose(
+            np.sum(shap, axis=len(shap.shape) - 1), margin, rtol=1e-3
+        )
+
+        booster.set_param({"predictor": "cpu_predictor"})
+        shap = booster.predict(Xy, pred_contribs=True)
+        margin = booster.predict(Xy, output_margin=True)
+        np.testing.assert_allclose(
+            np.sum(shap, axis=len(shap.shape) - 1), margin, rtol=1e-3
+        )
+
     def test_predict_leaf_basic(self):
         gpu_leaf = run_predict_leaf('gpu_predictor')
         cpu_leaf = run_predict_leaf('cpu_predictor')