Implement GPU predict leaf. (#6187)
@@ -78,7 +78,7 @@ struct SparsePageLoader {
   __device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features,
                               bst_row_t num_rows, size_t entry_start)
       : use_shared(use_shared),
-        data(data),
+        data(data),
         entry_start(entry_start) {
     extern __shared__ float _smem[];
     smem = _smem;
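
The constructor grabs its scratch space with the `extern __shared__ float _smem[]` idiom: the array is declared unsized inside the kernel and gets its length from the third kernel-launch parameter. A minimal standalone sketch of that pattern, with hypothetical names (StageRowsKernel, kThreads, kCols) that are not part of this commit:

#include <cstddef>
#include <cstdio>
#include <cuda_runtime.h>

// Dynamic shared memory: the buffer is unsized in the kernel and sized by the
// third launch parameter, the same pattern as SparsePageLoader's _smem.
__global__ void StageRowsKernel(const float* in, float* out, int cols) {
  extern __shared__ float smem[];
  int tid = threadIdx.x;
  for (int f = 0; f < cols; ++f) {
    smem[tid * cols + f] = in[tid * cols + f];  // stage one row per thread
  }
  __syncthreads();
  for (int f = 0; f < cols; ++f) {
    out[tid * cols + f] = smem[tid * cols + f];
  }
}

int main() {
  constexpr int kThreads = 32, kCols = 8;
  size_t bytes = sizeof(float) * kThreads * kCols;
  float *in = nullptr, *out = nullptr;
  cudaMalloc(&in, bytes);
  cudaMalloc(&out, bytes);
  StageRowsKernel<<<1, kThreads, bytes>>>(in, out, kCols);  // bytes -> smem size
  cudaDeviceSynchronize();
  cudaFree(in);
  cudaFree(out);
  return 0;
}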
@@ -169,7 +169,7 @@ struct DeviceAdapterLoader {
 };
 
 template <typename Loader>
-__device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
+__device__ float GetLeafWeight(bst_row_t ridx, const RegTree::Node* tree,
                                common::Span<FeatureType const> split_types,
                                common::Span<RegTree::Segment const> d_cat_ptrs,
                                common::Span<uint32_t const> d_categories,
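
The signature change widens the row index from `bst_uint` to `bst_row_t`. In xgboost's headers `bst_uint` is a 32-bit unsigned integer while `bst_row_t` is `std::size_t`, so the wider type keeps row indexing safe for datasets whose row count overflows 32 bits. A sketch of the distinction; the aliases are restated here for illustration and mirror (but are not copied from) xgboost/base.h:

#include <cstddef>
#include <cstdint>

// Restated aliases; see xgboost/base.h for the authoritative definitions.
using bst_uint  = std::uint32_t;  // caps row indices at ~4.29 billion
using bst_row_t = std::size_t;    // 64 bits on the platforms GPU builds target

// On a 64-bit build this holds, which is the point of the widening:
static_assert(sizeof(bst_row_t) >= sizeof(bst_uint),
              "bst_row_t must represent any bst_uint row index");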
@@ -201,6 +201,49 @@ __device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
   return tree[nidx].LeafValue();
 }
 
+template <typename Loader>
+__device__ bst_node_t GetLeafIndex(bst_row_t ridx, const RegTree::Node* tree,
+                                   Loader const& loader) {
+  bst_node_t nidx = 0;
+  RegTree::Node n = tree[nidx];
+  while (!n.IsLeaf()) {
+    float fvalue = loader.GetElement(ridx, n.SplitIndex());
+    // Missing value
+    if (isnan(fvalue)) {
+      nidx = n.DefaultChild();
+      n = tree[nidx];
+    } else {
+      if (fvalue < n.SplitCond()) {
+        nidx = n.LeftChild();
+        n = tree[nidx];
+      } else {
+        nidx = n.RightChild();
+        n = tree[nidx];
+      }
+    }
+  }
+  return nidx;
+}
+
+template <typename Loader, typename Data>
+__global__ void PredictLeafKernel(Data data,
+                                  common::Span<const RegTree::Node> d_nodes,
+                                  common::Span<float> d_out_predictions,
+                                  common::Span<size_t const> d_tree_segments,
+                                  size_t tree_begin, size_t tree_end, size_t num_features,
+                                  size_t num_rows, size_t entry_start, bool use_shared) {
+  bst_row_t ridx = blockDim.x * blockIdx.x + threadIdx.x;
+  if (ridx >= num_rows) {
+    return;
+  }
+  Loader loader(data, use_shared, num_features, num_rows, entry_start);
+  for (int tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
+    const RegTree::Node* d_tree = &d_nodes[d_tree_segments[tree_idx - tree_begin]];
+    auto leaf = GetLeafIndex(ridx, d_tree, loader);
+    d_out_predictions[ridx * (tree_end - tree_begin) + tree_idx] = leaf;
+  }
+}
+
 template <typename Loader, typename Data>
 __global__ void
 PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
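
GetLeafIndex is the plain decision-tree walk: start at the root, route missing values to the default child, otherwise branch on `fvalue < SplitCond()`. A host-side mirror of the same walk can serve as a CPU oracle when unit-testing PredictLeafKernel; `TestNode` below is a hypothetical stand-in for RegTree::Node and is not part of this commit:

#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical host-side oracle for the device-side GetLeafIndex walk.
struct TestNode {
  int left = -1, right = -1, default_child = -1;
  int split_index = 0;
  float split_cond = 0.0f;
  bool IsLeaf() const { return left == -1; }
};

int GetLeafIndexHost(const std::vector<TestNode>& tree,
                     const std::vector<float>& row) {
  int nidx = 0;
  while (!tree[nidx].IsLeaf()) {
    float fvalue = row[tree[nidx].split_index];
    if (std::isnan(fvalue)) {
      nidx = tree[nidx].default_child;  // missing value takes the default branch
    } else {
      nidx = fvalue < tree[nidx].split_cond ? tree[nidx].left
                                            : tree[nidx].right;
    }
  }
  return nidx;
}

int main() {
  // Depth-1 tree: node 0 splits feature 0 at 0.5; nodes 1 and 2 are leaves.
  std::vector<TestNode> tree(3);
  tree[0] = {1, 2, 1, 0, 0.5f};
  assert(GetLeafIndexHost(tree, {0.1f}) == 1);  // goes left
  assert(GetLeafIndexHost(tree, {0.9f}) == 2);  // goes right
  assert(GetLeafIndexHost(tree, {NAN}) == 1);   // missing: default child
  return 0;
}

On top of this walk, the kernel only adds the per-block Loader (for the shared-memory fast path) and the flattened per-row, per-tree output indexing.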
@@ -437,6 +480,19 @@ void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,
   });
 }
 
+namespace {
+template <size_t kBlockThreads>
+size_t SharedMemoryBytes(size_t cols, size_t max_shared_memory_bytes) {
+  // There is no way max_shared_memory_bytes can be equal to 0.
+  CHECK_GT(max_shared_memory_bytes, 0);
+  size_t shared_memory_bytes =
+      static_cast<size_t>(sizeof(float) * cols * kBlockThreads);
+  if (shared_memory_bytes > max_shared_memory_bytes) {
+    shared_memory_bytes = 0;
+  }
+  return shared_memory_bytes;
+}
+}  // anonymous namespace
+
 class GPUPredictor : public xgboost::Predictor {
  private:
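
The helper centralizes the shared-memory sizing that was previously duplicated at each launch site: the per-block buffer is one float per feature per thread, and a return value of 0 signals the caller to fall back to global memory. A standalone restatement with worked numbers; `SharedMemoryBytesSketch` and the 48 KiB limit are illustrative assumptions, since the real limit is queried from the device:

#include <cassert>
#include <cstddef>

// Standalone copy of the helper's logic to show the fallback arithmetic.
template <size_t kBlockThreads>
size_t SharedMemoryBytesSketch(size_t cols, size_t max_shared_memory_bytes) {
  size_t shared_memory_bytes = sizeof(float) * cols * kBlockThreads;
  if (shared_memory_bytes > max_shared_memory_bytes) {
    shared_memory_bytes = 0;  // 0 doubles as the "use global memory" signal
  }
  return shared_memory_bytes;
}

int main() {
  constexpr size_t kMax = 48 * 1024;  // 48 KiB, a common per-block default
  // 128 threads * 90 features * 4 bytes = 46,080 bytes: fits, shared memory on.
  assert(SharedMemoryBytesSketch<128>(90, kMax) == 46080);
  // 128 threads * 100 features * 4 bytes = 51,200 bytes: too big, falls back.
  assert(SharedMemoryBytesSketch<128>(100, kMax) == 0);
  return 0;
}

A zero return lets the caller derive use_shared with a single comparison instead of tracking a separate flag.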
@@ -450,13 +506,10 @@ class GPUPredictor : public xgboost::Predictor {
     size_t num_rows = batch.Size();
     auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
 
-    auto shared_memory_bytes =
-        static_cast<size_t>(sizeof(float) * num_features * BLOCK_THREADS);
-    bool use_shared = true;
-    if (shared_memory_bytes > max_shared_memory_bytes_) {
-      shared_memory_bytes = 0;
-      use_shared = false;
-    }
+    size_t shared_memory_bytes =
+        SharedMemoryBytes<BLOCK_THREADS>(num_features, max_shared_memory_bytes_);
+    bool use_shared = shared_memory_bytes != 0;
+
     size_t entry_start = 0;
     SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
                         num_features);
@@ -608,13 +661,9 @@ class GPUPredictor : public xgboost::Predictor {
     const uint32_t BLOCK_THREADS = 128;
     auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(info.num_row_, BLOCK_THREADS));
 
-    auto shared_memory_bytes =
-        static_cast<size_t>(sizeof(float) * m->NumColumns() * BLOCK_THREADS);
-    bool use_shared = true;
-    if (shared_memory_bytes > max_shared_memory_bytes) {
-      shared_memory_bytes = 0;
-      use_shared = false;
-    }
+    size_t shared_memory_bytes =
+        SharedMemoryBytes<BLOCK_THREADS>(info.num_col_, max_shared_memory_bytes);
+    bool use_shared = shared_memory_bytes != 0;
     size_t entry_start = 0;
 
     dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
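
Both call sites now derive use_shared from whether the helper returned a non-zero size and pass that same size straight into the launch. `dh::LaunchKernel {grid, block, shmem}(kernel, args...)` is xgboost's device-helper wrapper; in plain CUDA the equivalent launch looks roughly like the sketch below, where DummyKernel and the values are hypothetical:

#include <cstddef>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void DummyKernel(int n) {
  if (blockDim.x * blockIdx.x + threadIdx.x == 0) {
    printf("n = %d\n", n);
  }
}

int main() {
  unsigned grid = 4, block = 128;
  // The third launch parameter is the dynamic shared memory size; passing 0
  // (the fallback value from SharedMemoryBytes) simply allocates none.
  size_t shared_memory_bytes = 0;
  DummyKernel<<<grid, block, shared_memory_bytes>>>(42);
  cudaDeviceSynchronize();
  return 0;
}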
@@ -780,11 +829,65 @@ class GPUPredictor : public xgboost::Predictor {
                << " is not implemented in GPU Predictor.";
   }
 
-  void PredictLeaf(DMatrix*, std::vector<bst_float>*,
-                   const gbm::GBTreeModel&,
-                   unsigned) override {
-    LOG(FATAL) << "[Internal error]: " << __func__
-               << " is not implemented in GPU Predictor.";
+  void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* predictions,
+                   const gbm::GBTreeModel& model,
+                   unsigned ntree_limit) override {
+    dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
+    ConfigureDevice(generic_param_->gpu_id);
+
+    const MetaInfo& info = p_fmat->Info();
+    constexpr uint32_t kBlockThreads = 128;
+    size_t shared_memory_bytes =
+        SharedMemoryBytes<kBlockThreads>(info.num_col_, max_shared_memory_bytes_);
+    bool use_shared = shared_memory_bytes != 0;
+    bst_feature_t num_features = info.num_col_;
+    bst_row_t num_rows = info.num_row_;
+    size_t entry_start = 0;
+
+    uint32_t real_ntree_limit = ntree_limit * model.learner_model_param->num_output_group;
+    if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
+      real_ntree_limit = static_cast<uint32_t>(model.trees.size());
+    }
+    predictions->SetDevice(generic_param_->gpu_id);
+    predictions->Resize(num_rows * real_ntree_limit);
+    model_.Init(model, 0, real_ntree_limit, generic_param_->gpu_id);
+
+    if (p_fmat->PageExists<SparsePage>()) {
+      for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
+        batch.data.SetDevice(generic_param_->gpu_id);
+        batch.offset.SetDevice(generic_param_->gpu_id);
+        bst_row_t batch_offset = 0;
+        SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
+                            model.learner_model_param->num_feature};
+        size_t num_rows = batch.Size();
+        auto grid =
+            static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
+        dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} (
+            PredictLeafKernel<SparsePageLoader, SparsePageView>, data,
+            model_.nodes.ConstDeviceSpan(),
+            predictions->DeviceSpan().subspan(batch_offset),
+            model_.tree_segments.ConstDeviceSpan(),
+            model_.tree_beg_, model_.tree_end_, num_features, num_rows,
+            entry_start, use_shared);
+        batch_offset += batch.Size();
+      }
+    } else {
+      for (auto const& batch : p_fmat->GetBatches<EllpackPage>()) {
+        bst_row_t batch_offset = 0;
+        EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(generic_param_->gpu_id)};
+        size_t num_rows = batch.Size();
+        auto grid =
+            static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
+        dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} (
+            PredictLeafKernel<EllpackLoader, EllpackDeviceAccessor>, data,
+            model_.nodes.ConstDeviceSpan(),
+            predictions->DeviceSpan().subspan(batch_offset),
+            model_.tree_segments.ConstDeviceSpan(),
+            model_.tree_beg_, model_.tree_end_, num_features, num_rows,
+            entry_start, use_shared);
+        batch_offset += batch.Size();
+      }
+    }
   }
 
   void Configure(const std::vector<std::pair<std::string, std::string>>& cfg) override {
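
The new PredictLeaf sizes the output as `num_rows * real_ntree_limit`, and the kernel writes `d_out_predictions[ridx * (tree_end - tree_begin) + tree_idx]`, so the result is a row-major rows-by-trees matrix of leaf indices, each a bst_node_t stored as a float in the shared prediction buffer. A small accessor sketch of that layout; `LeafFor` is a hypothetical helper, not part of this commit:

#include <cstddef>
#include <vector>

// Leaf index for (row, tree) in the flat buffer PredictLeaf fills:
// the kernel writes out[row * num_trees + tree].
inline float LeafFor(const std::vector<float>& out, std::size_t num_trees,
                     std::size_t row, std::size_t tree) {
  return out[row * num_trees + tree];
}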
@@ -801,7 +904,7 @@ class GPUPredictor : public xgboost::Predictor {
 
   std::mutex lock_;
   DeviceModel model_;
-  size_t max_shared_memory_bytes_;
+  size_t max_shared_memory_bytes_ { 0 };
 };
 
 XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")