Support column split in GPU predictor (#9343)
This commit is contained in:
parent
f90771eec6
commit
3a0f787703
@ -122,10 +122,11 @@ template <typename Func>
|
||||
void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
|
||||
std::size_t size, cudaStream_t stream) {
|
||||
dh::LaunchN(size, stream, [=] __device__(std::size_t idx) {
|
||||
out_buffer[idx] = device_buffer[idx];
|
||||
auto result = device_buffer[idx];
|
||||
for (auto rank = 1; rank < world_size; rank++) {
|
||||
out_buffer[idx] = func(out_buffer[idx], device_buffer[rank * size + idx]);
|
||||
result = func(result, device_buffer[rank * size + idx]);
|
||||
}
|
||||
out_buffer[idx] = result;
|
||||
});
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
@ -467,7 +467,6 @@ class ColumnSplitHelper {
|
||||
void MaskOneTree(RegTree::FVec const &feat, std::size_t tree_id, std::size_t row_id) {
|
||||
auto const &tree = *model_.trees[tree_id];
|
||||
auto const &cats = tree.GetCategoriesMatrix();
|
||||
auto const has_categorical = tree.HasCategoricalSplit();
|
||||
bst_node_t n_nodes = tree.GetNodes().size();
|
||||
|
||||
for (bst_node_t nid = 0; nid < n_nodes; nid++) {
|
||||
@ -484,16 +483,10 @@ class ColumnSplitHelper {
|
||||
}
|
||||
|
||||
auto const fvalue = feat.GetFvalue(split_index);
|
||||
if (has_categorical && common::IsCat(cats.split_type, nid)) {
|
||||
auto const node_categories =
|
||||
cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
|
||||
if (!common::Decision(node_categories, fvalue)) {
|
||||
decision_bits_.Set(bit_index);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (fvalue >= node.SplitCond()) {
|
||||
auto const decision = tree.HasCategoricalSplit()
|
||||
? GetDecision<true>(node, nid, fvalue, cats)
|
||||
: GetDecision<false>(node, nid, fvalue, cats);
|
||||
if (decision) {
|
||||
decision_bits_.Set(bit_index);
|
||||
}
|
||||
}
|
||||
@ -511,7 +504,7 @@ class ColumnSplitHelper {
|
||||
if (missing_bits_.Check(bit_index)) {
|
||||
return node.DefaultChild();
|
||||
} else {
|
||||
return node.LeftChild() + decision_bits_.Check(bit_index);
|
||||
return node.LeftChild() + !decision_bits_.Check(bit_index);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -11,9 +11,11 @@
|
||||
#include <any> // for any, any_cast
|
||||
#include <memory>
|
||||
|
||||
#include "../collective/communicator-inl.cuh"
|
||||
#include "../common/bitfield.h"
|
||||
#include "../common/categorical.h"
|
||||
#include "../common/common.h"
|
||||
#include "../common/cuda_context.cuh"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../data/device_adapter.cuh"
|
||||
#include "../data/ellpack_page.cuh"
|
||||
@ -110,13 +112,11 @@ struct SparsePageLoader {
|
||||
bool use_shared;
|
||||
SparsePageView data;
|
||||
float* smem;
|
||||
size_t entry_start;
|
||||
|
||||
__device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features,
|
||||
bst_row_t num_rows, size_t entry_start, float)
|
||||
: use_shared(use_shared),
|
||||
data(data),
|
||||
entry_start(entry_start) {
|
||||
data(data) {
|
||||
extern __shared__ float _smem[];
|
||||
smem = _smem;
|
||||
// Copy instances
|
||||
@ -622,6 +622,199 @@ size_t SharedMemoryBytes(size_t cols, size_t max_shared_memory_bytes) {
|
||||
}
|
||||
return shared_memory_bytes;
|
||||
}
|
||||
|
||||
using BitVector = LBitField64;
|
||||
|
||||
__global__ void MaskBitVectorKernel(
|
||||
SparsePageView data, common::Span<RegTree::Node const> d_nodes,
|
||||
common::Span<std::size_t const> d_tree_segments, common::Span<int const> d_tree_group,
|
||||
common::Span<FeatureType const> d_tree_split_types,
|
||||
common::Span<std::uint32_t const> d_cat_tree_segments,
|
||||
common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
|
||||
common::Span<std::uint32_t const> d_categories, BitVector decision_bits, BitVector missing_bits,
|
||||
std::size_t tree_begin, std::size_t tree_end, std::size_t num_features, std::size_t num_rows,
|
||||
std::size_t entry_start, std::size_t num_nodes, bool use_shared, float missing) {
|
||||
auto const row_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (row_idx >= num_rows) {
|
||||
return;
|
||||
}
|
||||
SparsePageLoader loader(data, use_shared, num_features, num_rows, entry_start, missing);
|
||||
|
||||
std::size_t tree_offset = 0;
|
||||
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
TreeView d_tree{tree_begin, tree_idx, d_nodes,
|
||||
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
|
||||
d_cat_node_segments, d_categories};
|
||||
auto const tree_nodes = d_tree.d_tree.size();
|
||||
for (auto nid = 0; nid < tree_nodes; nid++) {
|
||||
auto const& node = d_tree.d_tree[nid];
|
||||
if (node.IsDeleted() || node.IsLeaf()) {
|
||||
continue;
|
||||
}
|
||||
auto const fvalue = loader.GetElement(row_idx, node.SplitIndex());
|
||||
auto const is_missing = common::CheckNAN(fvalue);
|
||||
auto const bit_index = row_idx * num_nodes + tree_offset + nid;
|
||||
if (is_missing) {
|
||||
missing_bits.Set(bit_index);
|
||||
} else {
|
||||
auto const decision = d_tree.HasCategoricalSplit()
|
||||
? GetDecision<true>(node, nid, fvalue, d_tree.cats)
|
||||
: GetDecision<false>(node, nid, fvalue, d_tree.cats);
|
||||
if (decision) {
|
||||
decision_bits.Set(bit_index);
|
||||
}
|
||||
}
|
||||
}
|
||||
tree_offset += tree_nodes;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ float GetLeafWeightByBitVector(bst_row_t ridx, TreeView const& tree,
|
||||
BitVector const& decision_bits,
|
||||
BitVector const& missing_bits, std::size_t num_nodes,
|
||||
std::size_t tree_offset) {
|
||||
bst_node_t nidx = 0;
|
||||
RegTree::Node n = tree.d_tree[nidx];
|
||||
while (!n.IsLeaf()) {
|
||||
auto const bit_index = ridx * num_nodes + tree_offset + nidx;
|
||||
if (missing_bits.Check(bit_index)) {
|
||||
nidx = n.DefaultChild();
|
||||
} else {
|
||||
nidx = n.LeftChild() + !decision_bits.Check(bit_index);
|
||||
}
|
||||
n = tree.d_tree[nidx];
|
||||
}
|
||||
return tree.d_tree[nidx].LeafValue();
|
||||
}
|
||||
|
||||
__global__ void PredictByBitVectorKernel(
|
||||
common::Span<RegTree::Node const> d_nodes, common::Span<float> d_out_predictions,
|
||||
common::Span<std::size_t const> d_tree_segments, common::Span<int const> d_tree_group,
|
||||
common::Span<FeatureType const> d_tree_split_types,
|
||||
common::Span<std::uint32_t const> d_cat_tree_segments,
|
||||
common::Span<RegTree::CategoricalSplitMatrix::Segment const> d_cat_node_segments,
|
||||
common::Span<std::uint32_t const> d_categories, BitVector decision_bits, BitVector missing_bits,
|
||||
std::size_t tree_begin, std::size_t tree_end, std::size_t num_rows, std::size_t num_nodes,
|
||||
std::uint32_t num_group) {
|
||||
auto const row_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (row_idx >= num_rows) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::size_t tree_offset = 0;
|
||||
if (num_group == 1) {
|
||||
float sum = 0;
|
||||
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
TreeView d_tree{tree_begin, tree_idx, d_nodes,
|
||||
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
|
||||
d_cat_node_segments, d_categories};
|
||||
sum += GetLeafWeightByBitVector(row_idx, d_tree, decision_bits, missing_bits, num_nodes,
|
||||
tree_offset);
|
||||
tree_offset += d_tree.d_tree.size();
|
||||
}
|
||||
d_out_predictions[row_idx] += sum;
|
||||
} else {
|
||||
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
auto const tree_group = d_tree_group[tree_idx];
|
||||
TreeView d_tree{tree_begin, tree_idx, d_nodes,
|
||||
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
|
||||
d_cat_node_segments, d_categories};
|
||||
bst_uint out_prediction_idx = row_idx * num_group + tree_group;
|
||||
d_out_predictions[out_prediction_idx] += GetLeafWeightByBitVector(
|
||||
row_idx, d_tree, decision_bits, missing_bits, num_nodes, tree_offset);
|
||||
tree_offset += d_tree.d_tree.size();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class ColumnSplitHelper {
|
||||
public:
|
||||
explicit ColumnSplitHelper(Context const* ctx) : ctx_{ctx} {}
|
||||
|
||||
void PredictBatch(DMatrix* dmat, HostDeviceVector<float>* out_preds,
|
||||
gbm::GBTreeModel const& model, DeviceModel const& d_model) const {
|
||||
CHECK(dmat->PageExists<SparsePage>()) << "Column split for external memory is not support.";
|
||||
PredictDMatrix(dmat, out_preds, d_model, model.learner_model_param->num_feature,
|
||||
model.learner_model_param->num_output_group);
|
||||
}
|
||||
|
||||
private:
|
||||
using BitType = BitVector::value_type;
|
||||
|
||||
void PredictDMatrix(DMatrix* dmat, HostDeviceVector<float>* out_preds, DeviceModel const& model,
|
||||
bst_feature_t num_features, std::uint32_t num_group) const {
|
||||
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
||||
dh::caching_device_vector<BitType> decision_storage{};
|
||||
dh::caching_device_vector<BitType> missing_storage{};
|
||||
|
||||
auto constexpr kBlockThreads = 128;
|
||||
auto const max_shared_memory_bytes = dh::MaxSharedMemory(ctx_->gpu_id);
|
||||
auto const shared_memory_bytes =
|
||||
SharedMemoryBytes<kBlockThreads>(num_features, max_shared_memory_bytes);
|
||||
auto const use_shared = shared_memory_bytes != 0;
|
||||
|
||||
auto const num_nodes = model.nodes.Size();
|
||||
std::size_t batch_offset = 0;
|
||||
for (auto const& batch : dmat->GetBatches<SparsePage>()) {
|
||||
auto const num_rows = batch.Size();
|
||||
ResizeBitVectors(&decision_storage, &missing_storage, num_rows * num_nodes);
|
||||
BitVector decision_bits{dh::ToSpan(decision_storage)};
|
||||
BitVector missing_bits{dh::ToSpan(missing_storage)};
|
||||
|
||||
batch.offset.SetDevice(ctx_->gpu_id);
|
||||
batch.data.SetDevice(ctx_->gpu_id);
|
||||
std::size_t entry_start = 0;
|
||||
SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), num_features);
|
||||
|
||||
auto const grid = static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
|
||||
dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes, ctx_->CUDACtx()->Stream()} (
|
||||
MaskBitVectorKernel, data, model.nodes.ConstDeviceSpan(),
|
||||
model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(),
|
||||
model.split_types.ConstDeviceSpan(), model.categories_tree_segments.ConstDeviceSpan(),
|
||||
model.categories_node_segments.ConstDeviceSpan(), model.categories.ConstDeviceSpan(),
|
||||
decision_bits, missing_bits, model.tree_beg_, model.tree_end_, num_features, num_rows,
|
||||
entry_start, num_nodes, use_shared, nan(""));
|
||||
|
||||
AllReduceBitVectors(&decision_storage, &missing_storage);
|
||||
|
||||
dh::LaunchKernel {grid, kBlockThreads, 0, ctx_->CUDACtx()->Stream()} (
|
||||
PredictByBitVectorKernel, model.nodes.ConstDeviceSpan(),
|
||||
out_preds->DeviceSpan().subspan(batch_offset), model.tree_segments.ConstDeviceSpan(),
|
||||
model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(),
|
||||
model.categories_tree_segments.ConstDeviceSpan(),
|
||||
model.categories_node_segments.ConstDeviceSpan(), model.categories.ConstDeviceSpan(),
|
||||
decision_bits, missing_bits, model.tree_beg_, model.tree_end_, num_rows, num_nodes,
|
||||
num_group);
|
||||
|
||||
batch_offset += batch.Size() * num_group;
|
||||
}
|
||||
}
|
||||
|
||||
void AllReduceBitVectors(dh::caching_device_vector<BitType>* decision_storage,
|
||||
dh::caching_device_vector<BitType>* missing_storage) const {
|
||||
collective::AllReduce<collective::Operation::kBitwiseOR>(
|
||||
ctx_->gpu_id, decision_storage->data().get(), decision_storage->size());
|
||||
collective::AllReduce<collective::Operation::kBitwiseAND>(
|
||||
ctx_->gpu_id, missing_storage->data().get(), missing_storage->size());
|
||||
collective::Synchronize(ctx_->gpu_id);
|
||||
}
|
||||
|
||||
void ResizeBitVectors(dh::caching_device_vector<BitType>* decision_storage,
|
||||
dh::caching_device_vector<BitType>* missing_storage,
|
||||
std::size_t total_bits) const {
|
||||
auto const size = BitVector::ComputeStorageSize(total_bits);
|
||||
if (decision_storage->size() < size) {
|
||||
decision_storage->resize(size);
|
||||
}
|
||||
thrust::fill(ctx_->CUDACtx()->CTP(), decision_storage->begin(), decision_storage->end(), 0);
|
||||
if (missing_storage->size() < size) {
|
||||
missing_storage->resize(size);
|
||||
}
|
||||
thrust::fill(ctx_->CUDACtx()->CTP(), missing_storage->begin(), missing_storage->end(), 0);
|
||||
}
|
||||
|
||||
Context const* ctx_;
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
class GPUPredictor : public xgboost::Predictor {
|
||||
@ -697,6 +890,11 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
DeviceModel d_model;
|
||||
d_model.Init(model, tree_begin, tree_end, ctx_->gpu_id);
|
||||
|
||||
if (dmat->Info().IsColumnSplit()) {
|
||||
column_split_helper_.PredictBatch(dmat, out_preds, model, d_model);
|
||||
return;
|
||||
}
|
||||
|
||||
if (dmat->PageExists<SparsePage>()) {
|
||||
size_t batch_offset = 0;
|
||||
for (auto &batch : dmat->GetBatches<SparsePage>()) {
|
||||
@ -720,7 +918,8 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
|
||||
public:
|
||||
explicit GPUPredictor(Context const* ctx) : Predictor::Predictor{ctx} {}
|
||||
explicit GPUPredictor(Context const* ctx)
|
||||
: Predictor::Predictor{ctx}, column_split_helper_{ctx} {}
|
||||
|
||||
~GPUPredictor() override {
|
||||
if (ctx_->gpu_id >= 0 && ctx_->gpu_id < common::AllVisibleGPUs()) {
|
||||
@ -1019,6 +1218,8 @@ class GPUPredictor : public xgboost::Predictor {
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
ColumnSplitHelper column_split_helper_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_PREDICTOR(GPUPredictor, "gpu_predictor")
|
||||
|
||||
@ -7,6 +7,18 @@
|
||||
#include "xgboost/tree_model.h"
|
||||
|
||||
namespace xgboost::predictor {
|
||||
/** @brief Whether it should traverse to the left branch of a tree. */
|
||||
template <bool has_categorical>
|
||||
XGBOOST_DEVICE bool GetDecision(RegTree::Node const &node, bst_node_t nid, float fvalue,
|
||||
RegTree::CategoricalSplitMatrix const &cats) {
|
||||
if (has_categorical && common::IsCat(cats.split_type, nid)) {
|
||||
auto node_categories = cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
|
||||
return common::Decision(node_categories, fvalue);
|
||||
} else {
|
||||
return fvalue < node.SplitCond();
|
||||
}
|
||||
}
|
||||
|
||||
template <bool has_missing, bool has_categorical>
|
||||
inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bst_node_t nid,
|
||||
float fvalue, bool is_missing,
|
||||
@ -14,13 +26,7 @@ inline XGBOOST_DEVICE bst_node_t GetNextNode(const RegTree::Node &node, const bs
|
||||
if (has_missing && is_missing) {
|
||||
return node.DefaultChild();
|
||||
} else {
|
||||
if (has_categorical && common::IsCat(cats.split_type, nid)) {
|
||||
auto node_categories =
|
||||
cats.categories.subspan(cats.node_ptr[nid].beg, cats.node_ptr[nid].size);
|
||||
return common::Decision(node_categories, fvalue) ? node.LeftChild() : node.RightChild();
|
||||
} else {
|
||||
return node.LeftChild() + !(fvalue < node.SplitCond());
|
||||
}
|
||||
return node.LeftChild() + !GetDecision<has_categorical>(node, nid, fvalue, cats);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -57,6 +57,68 @@ TEST(GPUPredictor, Basic) {
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
void VerifyBasicColumnSplit(std::array<std::vector<float>, 32> const& expected_result) {
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
|
||||
auto ctx = MakeCUDACtx(rank);
|
||||
std::unique_ptr<Predictor> predictor =
|
||||
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &ctx));
|
||||
predictor->Configure({});
|
||||
|
||||
for (size_t i = 1; i < 33; i *= 2) {
|
||||
size_t n_row = i, n_col = i;
|
||||
auto dmat = RandomDataGenerator(n_row, n_col, 0).GenerateDMatrix();
|
||||
std::unique_ptr<DMatrix> sliced{dmat->SliceCol(world_size, rank)};
|
||||
|
||||
LearnerModelParam mparam{MakeMP(n_col, .5, 1, ctx.gpu_id)};
|
||||
gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx);
|
||||
|
||||
// Test predict batch
|
||||
PredictionCacheEntry out_predictions;
|
||||
|
||||
predictor->InitOutPredictions(sliced->Info(), &out_predictions.predictions, model);
|
||||
predictor->PredictBatch(sliced.get(), &out_predictions, model, 0);
|
||||
|
||||
std::vector<float>& out_predictions_h = out_predictions.predictions.HostVector();
|
||||
EXPECT_EQ(out_predictions_h, expected_result[i - 1]);
|
||||
}
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(GPUPredictor, MGPUBasicColumnSplit) {
|
||||
auto const n_gpus = common::AllVisibleGPUs();
|
||||
if (n_gpus <= 1) {
|
||||
GTEST_SKIP() << "Skipping MGPUIBasicColumnSplit test with # GPUs = " << n_gpus;
|
||||
}
|
||||
|
||||
auto ctx = MakeCUDACtx(0);
|
||||
std::unique_ptr<Predictor> predictor =
|
||||
std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &ctx));
|
||||
predictor->Configure({});
|
||||
|
||||
std::array<std::vector<float>, 32> result{};
|
||||
for (size_t i = 1; i < 33; i *= 2) {
|
||||
size_t n_row = i, n_col = i;
|
||||
auto dmat = RandomDataGenerator(n_row, n_col, 0).GenerateDMatrix();
|
||||
|
||||
LearnerModelParam mparam{MakeMP(n_col, .5, 1, ctx.gpu_id)};
|
||||
gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx);
|
||||
|
||||
// Test predict batch
|
||||
PredictionCacheEntry out_predictions;
|
||||
|
||||
predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model);
|
||||
predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
|
||||
|
||||
std::vector<float>& out_predictions_h = out_predictions.predictions.HostVector();
|
||||
result[i - 1] = out_predictions_h;
|
||||
}
|
||||
|
||||
RunWithInMemoryCommunicator(n_gpus, VerifyBasicColumnSplit, result);
|
||||
}
|
||||
|
||||
TEST(GPUPredictor, EllpackBasic) {
|
||||
size_t constexpr kCols {8};
|
||||
for (size_t bins = 2; bins < 258; bins += 16) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user