Implement categorical data support for SHAP. (#7053)

* Add CPU implementation.
* Update GPUTreeSHAP.
* Add GPU implementation by defining custom split condition.
This commit is contained in:
Jiaming Yuan
2021-06-25 19:02:46 +08:00
committed by GitHub
parent 663136aa08
commit 8fa32fdda2
12 changed files with 287 additions and 50 deletions

View File

@@ -87,9 +87,11 @@ struct BitFieldContainer {
BitFieldContainer() = default;
XGBOOST_DEVICE explicit BitFieldContainer(common::Span<value_type> bits) : bits_{bits} {}
XGBOOST_DEVICE BitFieldContainer(BitFieldContainer const& other) : bits_{other.bits_} {}
BitFieldContainer &operator=(BitFieldContainer const &that) = default;
BitFieldContainer &operator=(BitFieldContainer &&that) = default;
common::Span<value_type> Bits() { return bits_; }
common::Span<value_type const> Bits() const { return bits_; }
XGBOOST_DEVICE common::Span<value_type> Bits() { return bits_; }
XGBOOST_DEVICE common::Span<value_type const> Bits() const { return bits_; }
/*\brief Compute the size of needed memory allocation. The returned value is in terms
* of number of elements with `BitFieldContainer::value_type'.

View File

@@ -42,6 +42,12 @@ inline XGBOOST_DEVICE bool Decision(common::Span<uint32_t const> cats, bst_cat_t
return !s_cats.Check(cat);
}
struct IsCatOp {
XGBOOST_DEVICE bool operator()(FeatureType ft) {
return ft == FeatureType::kCategorical;
}
};
using CatBitField = LBitField32;
using KCatBitField = CLBitField32;
} // namespace common

View File

@@ -8,6 +8,7 @@
#include "device_helpers.cuh"
#include "quantile.h"
#include "timer.h"
#include "categorical.h"
namespace xgboost {
namespace common {
@@ -17,11 +18,6 @@ using WQSketch = WQuantileSketch<bst_float, bst_float>;
using SketchEntry = WQSketch::Entry;
namespace detail {
struct IsCatOp {
XGBOOST_DEVICE bool operator()(FeatureType ft) {
return ft == FeatureType::kCategorical;
}
};
struct SketchUnique {
XGBOOST_DEVICE bool operator()(SketchEntry const& a, SketchEntry const& b) const {
return a.value - b.value == 0;
@@ -122,7 +118,7 @@ class SketchContainer {
has_categorical_ =
!d_feature_types.empty() &&
thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types),
detail::IsCatOp{});
common::IsCatOp{});
timer_.Init(__func__);
}