Updates to GPUTreeShap (#6087)
* Extract paths on device * Update GPUTreeShap
This commit is contained in:
parent
0e2d5669f6
commit
2e907abdb8
@ -1 +1 @@
|
||||
Subproject commit 04410099299ec918c75d00e385da35cf2e5aa87c
|
||||
Subproject commit 1de23c95ff07d086db02837fb4a746b6924abbd5
|
||||
@ -337,6 +337,9 @@ class RegTree : public Model {
|
||||
/*! \brief get const reference to nodes */
|
||||
const std::vector<Node>& GetNodes() const { return nodes_; }
|
||||
|
||||
/*! \brief get const reference to stats */
|
||||
const std::vector<RTreeNodeStat>& GetStats() const { return stats_; }
|
||||
|
||||
/*! \brief get node statistics given nid */
|
||||
RTreeNodeStat& Stat(int nid) {
|
||||
return stats_[nid];
|
||||
|
||||
@ -404,6 +404,7 @@ template class HostDeviceVector<Entry>;
|
||||
template class HostDeviceVector<uint64_t>; // bst_row_t
|
||||
template class HostDeviceVector<uint32_t>; // bst_feature_t
|
||||
template class HostDeviceVector<RegTree::Node>;
|
||||
template class HostDeviceVector<RTreeNodeStat>;
|
||||
|
||||
#if defined(__APPLE__)
|
||||
/*
|
||||
|
||||
@ -223,6 +223,7 @@ class DeviceModel {
|
||||
public:
|
||||
// Need to lazily construct the vectors because GPU id is only known at runtime
|
||||
HostDeviceVector<RegTree::Node> nodes;
|
||||
HostDeviceVector<RTreeNodeStat> stats;
|
||||
HostDeviceVector<size_t> tree_segments;
|
||||
HostDeviceVector<int> tree_group;
|
||||
size_t tree_beg_; // NOLINT
|
||||
@ -246,22 +247,116 @@ class DeviceModel {
|
||||
|
||||
nodes = std::move(HostDeviceVector<RegTree::Node>(h_tree_segments.back(), RegTree::Node(),
|
||||
gpu_id));
|
||||
auto& h_nodes = nodes.HostVector();
|
||||
stats = std::move(HostDeviceVector<RTreeNodeStat>(h_tree_segments.back(),
|
||||
RTreeNodeStat(), gpu_id));
|
||||
auto d_nodes = nodes.DevicePointer();
|
||||
auto d_stats = stats.DevicePointer();
|
||||
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
auto& src_nodes = model.trees.at(tree_idx)->GetNodes();
|
||||
std::copy(src_nodes.begin(), src_nodes.end(),
|
||||
h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
|
||||
auto& src_stats = model.trees.at(tree_idx)->GetStats();
|
||||
dh::safe_cuda(cudaMemcpyAsync(
|
||||
d_nodes + h_tree_segments[tree_idx - tree_begin], src_nodes.data(),
|
||||
sizeof(RegTree::Node) * src_nodes.size(), cudaMemcpyDefault));
|
||||
dh::safe_cuda(cudaMemcpyAsync(
|
||||
d_stats + h_tree_segments[tree_idx - tree_begin], src_stats.data(),
|
||||
sizeof(RTreeNodeStat) * src_stats.size(), cudaMemcpyDefault));
|
||||
}
|
||||
|
||||
tree_group = std::move(HostDeviceVector<int>(model.tree_info.size(), 0, gpu_id));
|
||||
auto& h_tree_group = tree_group.HostVector();
|
||||
std::memcpy(h_tree_group.data(), model.tree_info.data(), sizeof(int) * model.tree_info.size());
|
||||
auto d_tree_group = tree_group.DevicePointer();
|
||||
dh::safe_cuda(cudaMemcpyAsync(d_tree_group, model.tree_info.data(),
|
||||
sizeof(int) * model.tree_info.size(),
|
||||
cudaMemcpyDefault));
|
||||
this->tree_beg_ = tree_begin;
|
||||
this->tree_end_ = tree_end;
|
||||
this->num_group = model.learner_model_param->num_output_group;
|
||||
}
|
||||
};
|
||||
|
||||
struct PathInfo {
|
||||
int64_t leaf_position; // -1 not a leaf
|
||||
size_t length;
|
||||
size_t tree_idx;
|
||||
};
|
||||
|
||||
// Transform model into path element form for GPUTreeShap
|
||||
void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,
|
||||
const gbm::GBTreeModel& model, size_t tree_limit,
|
||||
int gpu_id) {
|
||||
DeviceModel device_model;
|
||||
device_model.Init(model, 0, tree_limit, gpu_id);
|
||||
dh::caching_device_vector<PathInfo> info(device_model.nodes.Size());
|
||||
dh::XGBCachingDeviceAllocator<PathInfo> alloc;
|
||||
auto d_nodes = device_model.nodes.ConstDeviceSpan();
|
||||
auto d_tree_segments = device_model.tree_segments.ConstDeviceSpan();
|
||||
auto nodes_transform = dh::MakeTransformIterator<PathInfo>(
|
||||
thrust::make_counting_iterator(0ull), [=] __device__(size_t idx) {
|
||||
auto n = d_nodes[idx];
|
||||
if (!n.IsLeaf() || n.IsDeleted()) {
|
||||
return PathInfo{-1, 0, 0};
|
||||
}
|
||||
size_t tree_idx =
|
||||
dh::SegmentId(d_tree_segments.begin(), d_tree_segments.end(), idx);
|
||||
size_t tree_offset = d_tree_segments[tree_idx];
|
||||
size_t path_length = 1;
|
||||
while (!n.IsRoot()) {
|
||||
n = d_nodes[n.Parent() + tree_offset];
|
||||
path_length++;
|
||||
}
|
||||
return PathInfo{int64_t(idx), path_length, tree_idx};
|
||||
});
|
||||
auto end = thrust::copy_if(
|
||||
thrust::cuda::par(alloc), nodes_transform,
|
||||
nodes_transform + d_nodes.size(), info.begin(),
|
||||
[=] __device__(const PathInfo& e) { return e.leaf_position != -1; });
|
||||
info.resize(end - info.begin());
|
||||
auto length_iterator = dh::MakeTransformIterator<size_t>(
|
||||
info.begin(),
|
||||
[=] __device__(const PathInfo& info) { return info.length; });
|
||||
dh::caching_device_vector<size_t> path_segments(info.size() + 1);
|
||||
thrust::exclusive_scan(thrust::cuda::par(alloc), length_iterator,
|
||||
length_iterator + info.size() + 1,
|
||||
path_segments.begin());
|
||||
|
||||
paths->resize(path_segments.back());
|
||||
|
||||
auto d_paths = paths->data().get();
|
||||
auto d_info = info.data().get();
|
||||
auto d_stats = device_model.stats.ConstDeviceSpan();
|
||||
auto d_tree_group = device_model.tree_group.ConstDeviceSpan();
|
||||
auto d_path_segments = path_segments.data().get();
|
||||
dh::LaunchN(gpu_id, info.size(), [=] __device__(size_t idx) {
|
||||
auto path_info = d_info[idx];
|
||||
size_t tree_offset = d_tree_segments[path_info.tree_idx];
|
||||
int group = d_tree_group[path_info.tree_idx];
|
||||
size_t child_idx = path_info.leaf_position;
|
||||
auto child = d_nodes[child_idx];
|
||||
float v = child.LeafValue();
|
||||
const float inf = std::numeric_limits<float>::infinity();
|
||||
size_t output_position = d_path_segments[idx + 1] - 1;
|
||||
while (!child.IsRoot()) {
|
||||
size_t parent_idx = tree_offset + child.Parent();
|
||||
double child_cover = d_stats[child_idx].sum_hess;
|
||||
double parent_cover = d_stats[parent_idx].sum_hess;
|
||||
double zero_fraction = child_cover / parent_cover;
|
||||
auto parent = d_nodes[parent_idx];
|
||||
bool is_left_path = (tree_offset + parent.LeftChild()) == child_idx;
|
||||
bool is_missing_path = (!parent.DefaultLeft() && !is_left_path) ||
|
||||
(parent.DefaultLeft() && is_left_path);
|
||||
float lower_bound = is_left_path ? -inf : parent.SplitCond();
|
||||
float upper_bound = is_left_path ? parent.SplitCond() : inf;
|
||||
d_paths[output_position--] = {
|
||||
idx, parent.SplitIndex(), group, lower_bound,
|
||||
upper_bound, is_missing_path, zero_fraction, v};
|
||||
child_idx = parent_idx;
|
||||
child = parent;
|
||||
}
|
||||
// Root node has feature -1
|
||||
d_paths[output_position] = {idx, -1, group, -inf, inf, false, 1.0, v};
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
class GPUPredictor : public xgboost::Predictor {
|
||||
private:
|
||||
void PredictInternal(const SparsePage& batch, size_t num_features,
|
||||
@ -495,17 +590,19 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
margin.empty() ? base_score : margin[idx];
|
||||
});
|
||||
|
||||
const auto& paths = this->ExtractPaths(model, real_ntree_limit);
|
||||
dh::device_vector<gpu_treeshap::PathElement> device_paths;
|
||||
ExtractPaths(&device_paths, model, real_ntree_limit,
|
||||
generic_param_->gpu_id);
|
||||
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
|
||||
batch.data.SetDevice(generic_param_->gpu_id);
|
||||
batch.offset.SetDevice(generic_param_->gpu_id);
|
||||
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
|
||||
model.learner_model_param->num_feature);
|
||||
gpu_treeshap::GPUTreeShap(
|
||||
X, paths, ngroup,
|
||||
X, device_paths.begin(), device_paths.end(), ngroup,
|
||||
phis.data().get() + batch.base_rowid * contributions_columns);
|
||||
}
|
||||
dh::safe_cuda(cudaMemcpyAsync(contribs.data(), phis.data().get(),
|
||||
dh::safe_cuda(cudaMemcpy(contribs.data(), phis.data().get(),
|
||||
sizeof(float) * phis.size(),
|
||||
cudaMemcpyDefault));
|
||||
}
|
||||
@ -563,49 +660,6 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<gpu_treeshap::PathElement> ExtractPaths(
|
||||
const gbm::GBTreeModel& model, size_t tree_limit) {
|
||||
std::vector<gpu_treeshap::PathElement> paths;
|
||||
size_t path_idx = 0;
|
||||
CHECK_LE(tree_limit, model.trees.size());
|
||||
for (auto i = 0ull; i < tree_limit; i++) {
|
||||
const auto& tree = *model.trees.at(i);
|
||||
size_t group = model.tree_info[i];
|
||||
const auto& nodes = tree.GetNodes();
|
||||
for (auto j = 0ull; j < nodes.size(); j++) {
|
||||
if (nodes[j].IsLeaf() && !nodes[j].IsDeleted()) {
|
||||
auto child = nodes[j];
|
||||
float v = child.LeafValue();
|
||||
size_t child_idx = j;
|
||||
const float inf = std::numeric_limits<float>::infinity();
|
||||
while (!child.IsRoot()) {
|
||||
float child_cover = tree.Stat(child_idx).sum_hess;
|
||||
float parent_cover = tree.Stat(child.Parent()).sum_hess;
|
||||
float zero_fraction = child_cover / parent_cover;
|
||||
CHECK(zero_fraction >= 0.0 && zero_fraction <= 1.0);
|
||||
auto parent = nodes[child.Parent()];
|
||||
CHECK(parent.LeftChild() == child_idx ||
|
||||
parent.RightChild() == child_idx);
|
||||
bool is_left_path = parent.LeftChild() == child_idx;
|
||||
bool is_missing_path = (!parent.DefaultLeft() && !is_left_path) ||
|
||||
(parent.DefaultLeft() && is_left_path);
|
||||
float lower_bound = is_left_path ? -inf : parent.SplitCond();
|
||||
float upper_bound = is_left_path ? parent.SplitCond() : inf;
|
||||
paths.emplace_back(path_idx, parent.SplitIndex(), group,
|
||||
lower_bound, upper_bound, is_missing_path,
|
||||
zero_fraction, v);
|
||||
child_idx = child.Parent();
|
||||
child = parent;
|
||||
}
|
||||
// Root node has feature -1
|
||||
paths.emplace_back(path_idx, -1, group, -inf, inf, false, 1.0, v);
|
||||
path_idx++;
|
||||
}
|
||||
}
|
||||
}
|
||||
return paths;
|
||||
}
|
||||
|
||||
std::mutex lock_;
|
||||
DeviceModel model_;
|
||||
size_t max_shared_memory_bytes_;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user