GPUTreeShap (#6038)
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/fill.h>
|
||||
#include <GPUTreeShap/gpu_treeshap.h>
|
||||
#include <memory>
|
||||
|
||||
#include "xgboost/data.h"
|
||||
@@ -27,72 +28,79 @@ DMLC_REGISTRY_FILE_TAG(gpu_predictor);
|
||||
/*!
 * \brief Non-owning device view over a sparse page stored in row-offset form.
 *
 * The entries of row r occupy the half-open range
 * [d_row_ptr[r], d_row_ptr[r + 1]) of d_data. The view also records the number
 * of feature columns so that rows holding exactly that many entries can be
 * indexed directly.
 */
struct SparsePageView {
  common::Span<const Entry> d_data;
  common::Span<const bst_row_t> d_row_ptr;
  bst_feature_t num_features;

  /*!
   * \param data          Entries of the page.
   * \param row_ptr       Row offsets into data; holds one more element than
   *                      there are rows (see NumRows()).
   * \param num_features  Number of feature columns, reported by NumCols().
   */
  XGBOOST_DEVICE SparsePageView(common::Span<const Entry> data,
                                common::Span<const bst_row_t> row_ptr,
                                bst_feature_t num_features)
      : d_data{data}, d_row_ptr{row_ptr}, num_features(num_features) {}

  /*!
   * \brief Look up the value of feature fidx in row ridx.
   *
   * \param ridx Row index within this page.
   * \param fidx Feature index to look up.
   * \return The stored value, or NaN when the row holds no entry for fidx.
   */
  __device__ float GetElement(size_t ridx, size_t fidx) const {
    auto begin_ptr = d_data.begin() + d_row_ptr[ridx];
    auto end_ptr = d_data.begin() + d_row_ptr[ridx + 1];
    if (end_ptr - begin_ptr == this->NumCols()) {
      // Bypass span check for dense data: the row holds NumCols() entries, so
      // the entry is addressed directly by fidx through the raw pointer.
      return d_data.data()[d_row_ptr[ridx] + fidx].fvalue;
    }
    // Binary search over the entries of the row.
    // NOTE(review): this presumes entries within a row are ordered by
    // ascending index - confirm against the code that builds the page.
    common::Span<const Entry>::iterator previous_middle;
    while (end_ptr != begin_ptr) {
      auto middle = begin_ptr + (end_ptr - begin_ptr) / 2;
      // begin_ptr is moved to middle (not middle + 1), so the interval can stop
      // shrinking; seeing the same midpoint twice means the search is done.
      if (middle == previous_middle) {
        break;
      } else {
        previous_middle = middle;
      }

      if (middle->index == fidx) {
        return middle->fvalue;
      } else if (middle->index < fidx) {
        begin_ptr = middle;
      } else {
        end_ptr = middle;
      }
    }
    // Value is missing
    return nanf("");
  }

  /*! \brief Number of rows, derived from the length of the row-offset array. */
  XGBOOST_DEVICE size_t NumRows() const { return d_row_ptr.size() - 1; }
  /*! \brief Number of feature columns supplied at construction. */
  XGBOOST_DEVICE size_t NumCols() const { return num_features; }
};
|
||||
|
||||
/*!
 * \brief Per-thread feature loader over a SparsePageView.
 *
 * When use_shared is set, each thread expands its own row into a slice of
 * dynamic shared memory (data.num_features floats per thread, pre-filled with
 * NaN) so later lookups are plain array reads. Otherwise lookups are delegated
 * to SparsePageView::GetElement.
 *
 * Preconditions visible from the code:
 *  - with use_shared, the launch must provide at least
 *    blockDim.x * data.num_features floats of dynamic shared memory;
 *  - use_shared must have the same value for every thread of a block, because
 *    the constructor calls __syncthreads() inside the use_shared branch.
 */
struct SparsePageLoader {
  bool use_shared;
  SparsePageView data;
  float* smem;
  size_t entry_start;

  /*!
   * \param data         View of the page to read from.
   * \param use_shared   Whether to stage rows in shared memory.
   * \param num_features Not read here; data.num_features is used instead. The
   *                     parameter is kept so the constructor signature stays
   *                     unchanged for existing callers.
   * \param num_rows     Number of rows; threads whose global index is not
   *                     below it stage nothing.
   * \param entry_start  Offset subtracted from row offsets when indexing
   *                     data.d_data during staging.
   */
  __device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features,
                              bst_row_t num_rows, size_t entry_start)
      : use_shared(use_shared),
        data(data),
        entry_start(entry_start) {
    extern __shared__ float _smem[];
    smem = _smem;
    // Copy instances
    if (use_shared) {
      bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
      int shared_elements = blockDim.x * data.num_features;
      // Mark every slot as missing before the present entries are scattered in.
      dh::BlockFill(smem, shared_elements, nanf(""));
      __syncthreads();
      if (global_idx < num_rows) {
        bst_uint elem_begin = data.d_row_ptr[global_idx];
        bst_uint elem_end = data.d_row_ptr[global_idx + 1];
        for (bst_uint elem_idx = elem_begin; elem_idx < elem_end; elem_idx++) {
          Entry elem = data.d_data[elem_idx - entry_start];
          smem[threadIdx.x * data.num_features + elem.index] = elem.fvalue;
        }
      }
      // Kept outside the bounds check so that every thread reaches the barrier.
      __syncthreads();
    }
  }

  /*!
   * \brief Value of feature fidx for row ridx, or NaN when it is missing.
   *
   * In shared-memory mode ridx is not used: the calling thread reads the slice
   * it staged in the constructor. Otherwise the lookup is forwarded to the
   * underlying view.
   * NOTE(review): the forwarded lookup does not apply entry_start, unlike the
   * staging loop above - confirm entry_start is zero on that path.
   */
  __device__ float GetElement(size_t ridx, size_t fidx) const {
    if (use_shared) {
      return smem[threadIdx.x * data.num_features + fidx];
    } else {
      return data.GetElement(ridx, fidx);
    }
  }
};
|
||||
@@ -103,7 +111,7 @@ struct EllpackLoader {
|
||||
bst_feature_t num_features, bst_row_t num_rows,
|
||||
size_t entry_start)
|
||||
: matrix{m} {}
|
||||
__device__ __forceinline__ float GetFvalue(int ridx, int fidx) const {
|
||||
__device__ __forceinline__ float GetElement(size_t ridx, size_t fidx) const {
|
||||
auto gidx = matrix.GetBinIndex(ridx, fidx);
|
||||
if (gidx == -1) {
|
||||
return nan("");
|
||||
@@ -150,7 +158,7 @@ struct DeviceAdapterLoader {
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
DEV_INLINE float GetFvalue(bst_row_t ridx, bst_feature_t fidx) const {
|
||||
DEV_INLINE float GetElement(size_t ridx, size_t fidx) const {
|
||||
if (use_shared) {
|
||||
return smem[threadIdx.x * columns + fidx];
|
||||
}
|
||||
@@ -163,7 +171,7 @@ __device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
|
||||
Loader* loader) {
|
||||
RegTree::Node n = tree[0];
|
||||
while (!n.IsLeaf()) {
|
||||
float fvalue = loader->GetFvalue(ridx, n.SplitIndex());
|
||||
float fvalue = loader->GetElement(ridx, n.SplitIndex());
|
||||
// Missing value
|
||||
if (isnan(fvalue)) {
|
||||
n = tree[n.DefaultChild()];
|
||||
@@ -273,7 +281,8 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
use_shared = false;
|
||||
}
|
||||
size_t entry_start = 0;
|
||||
SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan()};
|
||||
SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
|
||||
num_features);
|
||||
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
|
||||
PredictKernel<SparsePageLoader, SparsePageView>,
|
||||
data,
|
||||
@@ -447,6 +456,60 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
}
|
||||
|
||||
/*!
 * \brief Compute per-feature contributions for every row of p_fmat on the
 *        device and copy them into out_contribs.
 *
 * For each (row, output group) pair the output holds num_feature + 1 values.
 * The last value of each pair is initialised from the base margin entry for
 * that pair, or from the model's base score when no base margin is set; the
 * buffer is then passed to gpu_treeshap::GPUTreeShap together with the paths
 * produced by ExtractPaths.
 *
 * \param p_fmat       Input data; iterated as SparsePage batches.
 * \param out_contribs Output, resized to rows * (num_feature + 1) * groups.
 * \param model        Tree model to explain.
 * \param ntree_limit  Scaled by the number of output groups; a product of 0 or
 *                     one exceeding the number of trees selects all trees.
 * \param approximate  Must be false; true raises a fatal error.
 *
 * tree_weights, condition and condition_feature are not read here.
 */
void PredictContribution(DMatrix* p_fmat,
                         std::vector<bst_float>* out_contribs,
                         const gbm::GBTreeModel& model, unsigned ntree_limit,
                         std::vector<bst_float>* tree_weights,
                         bool approximate, int condition,
                         unsigned condition_feature) override {
  if (approximate) {
    LOG(FATAL) << "[Internal error]: " << __func__
               << " approximate is not implemented in GPU Predictor.";
  }

  const auto& param = *model.learner_model_param;
  auto& info = p_fmat->Info();

  // Resolve how many trees take part in the explanation.
  const size_t total_trees = model.trees.size();
  uint32_t real_ntree_limit = ntree_limit * param.num_output_group;
  if (real_ntree_limit == 0 || real_ntree_limit > total_trees) {
    real_ntree_limit = static_cast<uint32_t>(total_trees);
  }

  const int ngroup = param.num_output_group;
  CHECK_NE(ngroup, 0);

  // allocate space for (number of features + bias) times the number of rows
  const size_t contributions_columns = param.num_feature + 1;  // +1 for bias
  std::vector<bst_float>& contribs = *out_contribs;
  contribs.resize(info.num_row_ * contributions_columns *
                  param.num_output_group);

  // Device-side accumulator mirroring the host output, zero-initialised.
  dh::TemporaryArray<float> phis(contribs.size(), 0.0f);
  auto d_phis = phis.data().get();

  info.base_margin_.SetDevice(generic_param_->gpu_id);
  const auto margin = info.base_margin_.ConstDeviceSpan();
  const float base_score = param.base_score;

  // Add the base margin term to last column
  dh::LaunchN(generic_param_->gpu_id,
              info.num_row_ * param.num_output_group,
              [=] __device__(size_t idx) {
                const size_t bias_slot = (idx + 1) * contributions_columns - 1;
                d_phis[bias_slot] = margin.empty() ? base_score : margin[idx];
              });

  const std::vector<gpu_treeshap::PathElement> paths =
      this->ExtractPaths(model, real_ntree_limit);

  for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
    batch.data.SetDevice(generic_param_->gpu_id);
    batch.offset.SetDevice(generic_param_->gpu_id);
    SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
                     param.num_feature);
    // Each batch writes into the output starting at its first row.
    float* batch_phis = d_phis + batch.base_rowid * contributions_columns;
    gpu_treeshap::GPUTreeShap(X, paths, ngroup, batch_phis);
  }

  // Bring the accumulated contributions back into the host vector.
  const size_t num_bytes = sizeof(float) * phis.size();
  dh::safe_cuda(cudaMemcpyAsync(contribs.data(), d_phis, num_bytes,
                                cudaMemcpyDefault));
}
|
||||
|
||||
protected:
|
||||
void InitOutPredictions(const MetaInfo& info,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
@@ -478,16 +541,6 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
<< " is not implemented in GPU Predictor.";
|
||||
}
|
||||
|
||||
// Stub variant of PredictContribution: performs no prediction and always
// raises a fatal "[Internal error]" log entry naming this function.
// NOTE(review): in this diff listing the stub sits in the hunk that shrinks
// from 16 to 6 lines, so it looks like the pre-change version that the
// implementation listed earlier replaces - confirm before relying on it.
void PredictContribution(DMatrix* p_fmat,
                         std::vector<bst_float>* out_contribs,
                         const gbm::GBTreeModel& model, unsigned ntree_limit,
                         std::vector<bst_float>* tree_weights,
                         bool approximate, int condition,
                         unsigned condition_feature) override {
  LOG(FATAL) << "[Internal error]: " << __func__
             << " is not implemented in GPU Predictor.";
}
|
||||
|
||||
void PredictInteractionContributions(DMatrix* p_fmat,
|
||||
std::vector<bst_float>* out_contribs,
|
||||
const gbm::GBTreeModel& model,
|
||||
@@ -510,6 +563,49 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
}
|
||||
|
||||
/*!
 * \brief Flatten the first tree_limit trees of the model into root-to-leaf
 *        path elements for gpu_treeshap.
 *
 * Every leaf that is not deleted yields one path, identified by a running
 * path index. The path is emitted leaf-first: one element per ancestor split
 * (feature index, value interval of the branch taken, whether the branch is
 * the split's default direction, and the child/parent sum_hess ratio),
 * followed by a root element with feature -1, an unbounded interval and
 * ratio 1. All elements of a path carry the leaf value and the tree's group.
 *
 * \param model      Model providing trees and their group assignment.
 * \param tree_limit Number of leading trees to process; must not exceed the
 *                   number of trees in the model.
 * \return Path elements of all processed trees, in emission order.
 */
std::vector<gpu_treeshap::PathElement> ExtractPaths(
    const gbm::GBTreeModel& model, size_t tree_limit) {
  std::vector<gpu_treeshap::PathElement> paths;
  size_t path_idx = 0;
  CHECK_LE(tree_limit, model.trees.size());
  const float kInf = std::numeric_limits<float>::infinity();

  for (size_t tree_idx = 0; tree_idx < tree_limit; ++tree_idx) {
    const auto& tree = *model.trees.at(tree_idx);
    const size_t group = model.tree_info[tree_idx];
    const auto& nodes = tree.GetNodes();

    for (size_t leaf_idx = 0; leaf_idx < nodes.size(); ++leaf_idx) {
      // Only leaves that are still part of the tree terminate a path.
      if (!nodes[leaf_idx].IsLeaf() || nodes[leaf_idx].IsDeleted()) {
        continue;
      }

      const float leaf_value = nodes[leaf_idx].LeafValue();
      auto child = nodes[leaf_idx];
      size_t child_idx = leaf_idx;

      // Climb from the leaf to the root, emitting one element per split.
      while (!child.IsRoot()) {
        const auto parent_idx = child.Parent();
        const float child_cover = tree.Stat(child_idx).sum_hess;
        const float parent_cover = tree.Stat(parent_idx).sum_hess;
        float zero_fraction = child_cover / parent_cover;
        CHECK(zero_fraction >= 0.0 && zero_fraction <= 1.0);

        auto parent = nodes[parent_idx];
        CHECK(parent.LeftChild() == child_idx ||
              parent.RightChild() == child_idx);

        const bool is_left_path = parent.LeftChild() == child_idx;
        // The branch is the "missing" branch when it matches the default side.
        const bool default_left = parent.DefaultLeft();
        const bool is_missing_path = default_left == is_left_path;

        float lower_bound = parent.SplitCond();
        float upper_bound = kInf;
        if (is_left_path) {
          lower_bound = -kInf;
          upper_bound = parent.SplitCond();
        }

        paths.emplace_back(path_idx, parent.SplitIndex(), group,
                           lower_bound, upper_bound, is_missing_path,
                           zero_fraction, leaf_value);

        child_idx = parent_idx;
        child = parent;
      }

      // Root node has feature -1
      paths.emplace_back(path_idx, -1, group, -kInf, kInf, false, 1.0,
                         leaf_value);
      path_idx++;
    }
  }
  return paths;
}
|
||||
|
||||
std::mutex lock_;
|
||||
DeviceModel model_;
|
||||
size_t max_shared_memory_bytes_;
|
||||
|
||||
Reference in New Issue
Block a user