More support for column split in gpu predictor (#9562)
This commit is contained in:
parent
a343ae3b34
commit
d8c3cc92ae
@ -633,11 +633,12 @@ __global__ void MaskBitVectorKernel(
|
|||||||
common::Span<std::uint32_t const> d_categories, BitVector decision_bits, BitVector missing_bits,
|
common::Span<std::uint32_t const> d_categories, BitVector decision_bits, BitVector missing_bits,
|
||||||
std::size_t tree_begin, std::size_t tree_end, std::size_t num_features, std::size_t num_rows,
|
std::size_t tree_begin, std::size_t tree_end, std::size_t num_features, std::size_t num_rows,
|
||||||
std::size_t entry_start, std::size_t num_nodes, bool use_shared, float missing) {
|
std::size_t entry_start, std::size_t num_nodes, bool use_shared, float missing) {
|
||||||
|
// This needs to be always instantiated since the data is loaded cooperatively by all threads.
|
||||||
|
SparsePageLoader loader(data, use_shared, num_features, num_rows, entry_start, missing);
|
||||||
auto const row_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
auto const row_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
if (row_idx >= num_rows) {
|
if (row_idx >= num_rows) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
SparsePageLoader loader(data, use_shared, num_features, num_rows, entry_start, missing);
|
|
||||||
|
|
||||||
std::size_t tree_offset = 0;
|
std::size_t tree_offset = 0;
|
||||||
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||||
@ -668,7 +669,7 @@ __global__ void MaskBitVectorKernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
__device__ float GetLeafWeightByBitVector(bst_row_t ridx, TreeView const& tree,
|
__device__ bst_node_t GetLeafIndexByBitVector(bst_row_t ridx, TreeView const& tree,
|
||||||
BitVector const& decision_bits,
|
BitVector const& decision_bits,
|
||||||
BitVector const& missing_bits, std::size_t num_nodes,
|
BitVector const& missing_bits, std::size_t num_nodes,
|
||||||
std::size_t tree_offset) {
|
std::size_t tree_offset) {
|
||||||
@ -683,9 +684,19 @@ __device__ float GetLeafWeightByBitVector(bst_row_t ridx, TreeView const& tree,
|
|||||||
}
|
}
|
||||||
n = tree.d_tree[nidx];
|
n = tree.d_tree[nidx];
|
||||||
}
|
}
|
||||||
|
return nidx;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ float GetLeafWeightByBitVector(bst_row_t ridx, TreeView const& tree,
|
||||||
|
BitVector const& decision_bits,
|
||||||
|
BitVector const& missing_bits, std::size_t num_nodes,
|
||||||
|
std::size_t tree_offset) {
|
||||||
|
auto const nidx =
|
||||||
|
GetLeafIndexByBitVector(ridx, tree, decision_bits, missing_bits, num_nodes, tree_offset);
|
||||||
return tree.d_tree[nidx].LeafValue();
|
return tree.d_tree[nidx].LeafValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <bool predict_leaf>
|
||||||
__global__ void PredictByBitVectorKernel(
|
__global__ void PredictByBitVectorKernel(
|
||||||
common::Span<RegTree::Node const> d_nodes, common::Span<float> d_out_predictions,
|
common::Span<RegTree::Node const> d_nodes, common::Span<float> d_out_predictions,
|
||||||
common::Span<std::size_t const> d_tree_segments, common::Span<int const> d_tree_group,
|
common::Span<std::size_t const> d_tree_segments, common::Span<int const> d_tree_group,
|
||||||
@ -701,6 +712,17 @@ __global__ void PredictByBitVectorKernel(
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::size_t tree_offset = 0;
|
std::size_t tree_offset = 0;
|
||||||
|
if constexpr (predict_leaf) {
|
||||||
|
for (size_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
|
||||||
|
TreeView d_tree{tree_begin, tree_idx, d_nodes,
|
||||||
|
d_tree_segments, d_tree_split_types, d_cat_tree_segments,
|
||||||
|
d_cat_node_segments, d_categories};
|
||||||
|
auto const leaf = GetLeafIndexByBitVector(row_idx, d_tree, decision_bits, missing_bits,
|
||||||
|
num_nodes, tree_offset);
|
||||||
|
d_out_predictions[row_idx * (tree_end - tree_begin) + tree_idx] = static_cast<float>(leaf);
|
||||||
|
tree_offset += d_tree.d_tree.size();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
if (num_group == 1) {
|
if (num_group == 1) {
|
||||||
float sum = 0;
|
float sum = 0;
|
||||||
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||||
@ -725,6 +747,7 @@ __global__ void PredictByBitVectorKernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
class ColumnSplitHelper {
|
class ColumnSplitHelper {
|
||||||
public:
|
public:
|
||||||
@ -733,13 +756,21 @@ class ColumnSplitHelper {
|
|||||||
void PredictBatch(DMatrix* dmat, HostDeviceVector<float>* out_preds,
|
void PredictBatch(DMatrix* dmat, HostDeviceVector<float>* out_preds,
|
||||||
gbm::GBTreeModel const& model, DeviceModel const& d_model) const {
|
gbm::GBTreeModel const& model, DeviceModel const& d_model) const {
|
||||||
CHECK(dmat->PageExists<SparsePage>()) << "Column split for external memory is not support.";
|
CHECK(dmat->PageExists<SparsePage>()) << "Column split for external memory is not support.";
|
||||||
PredictDMatrix(dmat, out_preds, d_model, model.learner_model_param->num_feature,
|
PredictDMatrix<false>(dmat, out_preds, d_model, model.learner_model_param->num_feature,
|
||||||
|
model.learner_model_param->num_output_group);
|
||||||
|
}
|
||||||
|
|
||||||
|
void PredictLeaf(DMatrix* dmat, HostDeviceVector<float>* out_preds, gbm::GBTreeModel const& model,
|
||||||
|
DeviceModel const& d_model) const {
|
||||||
|
CHECK(dmat->PageExists<SparsePage>()) << "Column split for external memory is not support.";
|
||||||
|
PredictDMatrix<true>(dmat, out_preds, d_model, model.learner_model_param->num_feature,
|
||||||
model.learner_model_param->num_output_group);
|
model.learner_model_param->num_output_group);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
using BitType = BitVector::value_type;
|
using BitType = BitVector::value_type;
|
||||||
|
|
||||||
|
template <bool predict_leaf>
|
||||||
void PredictDMatrix(DMatrix* dmat, HostDeviceVector<float>* out_preds, DeviceModel const& model,
|
void PredictDMatrix(DMatrix* dmat, HostDeviceVector<float>* out_preds, DeviceModel const& model,
|
||||||
bst_feature_t num_features, std::uint32_t num_group) const {
|
bst_feature_t num_features, std::uint32_t num_group) const {
|
||||||
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
||||||
@ -777,7 +808,7 @@ class ColumnSplitHelper {
|
|||||||
AllReduceBitVectors(&decision_storage, &missing_storage);
|
AllReduceBitVectors(&decision_storage, &missing_storage);
|
||||||
|
|
||||||
dh::LaunchKernel {grid, kBlockThreads, 0, ctx_->CUDACtx()->Stream()} (
|
dh::LaunchKernel {grid, kBlockThreads, 0, ctx_->CUDACtx()->Stream()} (
|
||||||
PredictByBitVectorKernel, model.nodes.ConstDeviceSpan(),
|
PredictByBitVectorKernel<predict_leaf>, model.nodes.ConstDeviceSpan(),
|
||||||
out_preds->DeviceSpan().subspan(batch_offset), model.tree_segments.ConstDeviceSpan(),
|
out_preds->DeviceSpan().subspan(batch_offset), model.tree_segments.ConstDeviceSpan(),
|
||||||
model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(),
|
model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(),
|
||||||
model.categories_tree_segments.ConstDeviceSpan(),
|
model.categories_tree_segments.ConstDeviceSpan(),
|
||||||
@ -795,7 +826,6 @@ class ColumnSplitHelper {
|
|||||||
ctx_->gpu_id, decision_storage->data().get(), decision_storage->size());
|
ctx_->gpu_id, decision_storage->data().get(), decision_storage->size());
|
||||||
collective::AllReduce<collective::Operation::kBitwiseAND>(
|
collective::AllReduce<collective::Operation::kBitwiseAND>(
|
||||||
ctx_->gpu_id, missing_storage->data().get(), missing_storage->size());
|
ctx_->gpu_id, missing_storage->data().get(), missing_storage->size());
|
||||||
collective::Synchronize(ctx_->gpu_id);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ResizeBitVectors(dh::caching_device_vector<BitType>* decision_storage,
|
void ResizeBitVectors(dh::caching_device_vector<BitType>* decision_storage,
|
||||||
@ -889,7 +919,7 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
DeviceModel d_model;
|
DeviceModel d_model;
|
||||||
d_model.Init(model, tree_begin, tree_end, ctx_->gpu_id);
|
d_model.Init(model, tree_begin, tree_end, ctx_->gpu_id);
|
||||||
|
|
||||||
if (dmat->Info().IsColumnSplit()) {
|
if (info.IsColumnSplit()) {
|
||||||
column_split_helper_.PredictBatch(dmat, out_preds, model, d_model);
|
column_split_helper_.PredictBatch(dmat, out_preds, model, d_model);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -1018,6 +1048,9 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
if (tree_weights != nullptr) {
|
if (tree_weights != nullptr) {
|
||||||
LOG(FATAL) << "Dart booster feature " << not_implemented;
|
LOG(FATAL) << "Dart booster feature " << not_implemented;
|
||||||
}
|
}
|
||||||
|
CHECK(!p_fmat->Info().IsColumnSplit())
|
||||||
|
<< "Predict contribution support for column-wise data split is not yet implemented.";
|
||||||
|
|
||||||
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
||||||
out_contribs->SetDevice(ctx_->gpu_id);
|
out_contribs->SetDevice(ctx_->gpu_id);
|
||||||
if (tree_end == 0 || tree_end > model.trees.size()) {
|
if (tree_end == 0 || tree_end > model.trees.size()) {
|
||||||
@ -1136,17 +1169,9 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
const gbm::GBTreeModel &model,
|
const gbm::GBTreeModel &model,
|
||||||
unsigned tree_end) const override {
|
unsigned tree_end) const override {
|
||||||
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
|
||||||
auto max_shared_memory_bytes = ConfigureDevice(ctx_->gpu_id);
|
|
||||||
|
|
||||||
const MetaInfo& info = p_fmat->Info();
|
const MetaInfo& info = p_fmat->Info();
|
||||||
constexpr uint32_t kBlockThreads = 128;
|
|
||||||
size_t shared_memory_bytes = SharedMemoryBytes<kBlockThreads>(
|
|
||||||
info.num_col_, max_shared_memory_bytes);
|
|
||||||
bool use_shared = shared_memory_bytes != 0;
|
|
||||||
bst_feature_t num_features = info.num_col_;
|
|
||||||
bst_row_t num_rows = info.num_row_;
|
bst_row_t num_rows = info.num_row_;
|
||||||
size_t entry_start = 0;
|
|
||||||
|
|
||||||
if (tree_end == 0 || tree_end > model.trees.size()) {
|
if (tree_end == 0 || tree_end > model.trees.size()) {
|
||||||
tree_end = static_cast<uint32_t>(model.trees.size());
|
tree_end = static_cast<uint32_t>(model.trees.size());
|
||||||
}
|
}
|
||||||
@ -1155,6 +1180,19 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
DeviceModel d_model;
|
DeviceModel d_model;
|
||||||
d_model.Init(model, 0, tree_end, this->ctx_->gpu_id);
|
d_model.Init(model, 0, tree_end, this->ctx_->gpu_id);
|
||||||
|
|
||||||
|
if (info.IsColumnSplit()) {
|
||||||
|
column_split_helper_.PredictLeaf(p_fmat, predictions, model, d_model);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto max_shared_memory_bytes = ConfigureDevice(ctx_->gpu_id);
|
||||||
|
constexpr uint32_t kBlockThreads = 128;
|
||||||
|
size_t shared_memory_bytes = SharedMemoryBytes<kBlockThreads>(
|
||||||
|
info.num_col_, max_shared_memory_bytes);
|
||||||
|
bool use_shared = shared_memory_bytes != 0;
|
||||||
|
bst_feature_t num_features = info.num_col_;
|
||||||
|
size_t entry_start = 0;
|
||||||
|
|
||||||
if (p_fmat->PageExists<SparsePage>()) {
|
if (p_fmat->PageExists<SparsePage>()) {
|
||||||
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
|
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
|
||||||
batch.data.SetDevice(ctx_->gpu_id);
|
batch.data.SetDevice(ctx_->gpu_id);
|
||||||
|
|||||||
@ -127,8 +127,8 @@ TEST(CpuPredictor, IterationRange) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(CpuPredictor, IterationRangeColmnSplit) {
|
TEST(CpuPredictor, IterationRangeColmnSplit) {
|
||||||
Context ctx;
|
auto constexpr kWorldSize = 2;
|
||||||
TestIterationRangeColumnSplit(&ctx);
|
TestIterationRangeColumnSplit(kWorldSize, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CpuPredictor, ExternalMemory) {
|
TEST(CpuPredictor, ExternalMemory) {
|
||||||
@ -226,23 +226,21 @@ TEST(CPUPredictor, GHistIndexTraining) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(CPUPredictor, CategoricalPrediction) {
|
TEST(CPUPredictor, CategoricalPrediction) {
|
||||||
Context ctx;
|
TestCategoricalPrediction(false, false);
|
||||||
TestCategoricalPrediction(&ctx, false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CPUPredictor, CategoricalPredictionColumnSplit) {
|
TEST(CPUPredictor, CategoricalPredictionColumnSplit) {
|
||||||
Context ctx;
|
auto constexpr kWorldSize = 2;
|
||||||
TestCategoricalPredictionColumnSplit(&ctx);
|
RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPrediction, false, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CPUPredictor, CategoricalPredictLeaf) {
|
TEST(CPUPredictor, CategoricalPredictLeaf) {
|
||||||
Context ctx;
|
TestCategoricalPredictLeaf(false, false);
|
||||||
TestCategoricalPredictLeaf(&ctx, false);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CPUPredictor, CategoricalPredictLeafColumnSplit) {
|
TEST(CPUPredictor, CategoricalPredictLeafColumnSplit) {
|
||||||
Context ctx;
|
auto constexpr kWorldSize = 2;
|
||||||
TestCategoricalPredictLeafColumnSplit(&ctx);
|
RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPredictLeaf, false, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CpuPredictor, UpdatePredictionCache) {
|
TEST(CpuPredictor, UpdatePredictionCache) {
|
||||||
@ -256,8 +254,8 @@ TEST(CpuPredictor, LesserFeatures) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(CpuPredictor, LesserFeaturesColumnSplit) {
|
TEST(CpuPredictor, LesserFeaturesColumnSplit) {
|
||||||
Context ctx;
|
auto constexpr kWorldSize = 2;
|
||||||
TestPredictionWithLesserFeaturesColumnSplit(&ctx);
|
RunWithInMemoryCommunicator(kWorldSize, TestPredictionWithLesserFeaturesColumnSplit, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CpuPredictor, Sparse) {
|
TEST(CpuPredictor, Sparse) {
|
||||||
@ -267,9 +265,9 @@ TEST(CpuPredictor, Sparse) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST(CpuPredictor, SparseColumnSplit) {
|
TEST(CpuPredictor, SparseColumnSplit) {
|
||||||
Context ctx;
|
auto constexpr kWorldSize = 2;
|
||||||
TestSparsePredictionColumnSplit(&ctx, 0.2);
|
TestSparsePredictionColumnSplit(kWorldSize, false, 0.2);
|
||||||
TestSparsePredictionColumnSplit(&ctx, 0.8);
|
TestSparsePredictionColumnSplit(kWorldSize, false, 0.8);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CpuPredictor, Multi) {
|
TEST(CpuPredictor, Multi) {
|
||||||
|
|||||||
@ -206,6 +206,10 @@ TEST(GpuPredictor, LesserFeatures) {
|
|||||||
TestPredictionWithLesserFeatures(&ctx);
|
TestPredictionWithLesserFeatures(&ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(MGPUPredictorTest, LesserFeaturesColumnSplit) {
|
||||||
|
RunWithInMemoryCommunicator(world_size_, TestPredictionWithLesserFeaturesColumnSplit, true);
|
||||||
|
}
|
||||||
|
|
||||||
// Very basic test of empty model
|
// Very basic test of empty model
|
||||||
TEST(GPUPredictor, ShapStump) {
|
TEST(GPUPredictor, ShapStump) {
|
||||||
cudaSetDevice(0);
|
cudaSetDevice(0);
|
||||||
@ -270,14 +274,24 @@ TEST(GPUPredictor, IterationRange) {
|
|||||||
TestIterationRange(&ctx);
|
TestIterationRange(&ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(MGPUPredictorTest, IterationRangeColumnSplit) {
|
||||||
|
TestIterationRangeColumnSplit(world_size_, true);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(GPUPredictor, CategoricalPrediction) {
|
TEST(GPUPredictor, CategoricalPrediction) {
|
||||||
auto ctx = MakeCUDACtx(0);
|
TestCategoricalPrediction(true, false);
|
||||||
TestCategoricalPrediction(&ctx, false);
|
}
|
||||||
|
|
||||||
|
TEST_F(MGPUPredictorTest, CategoricalPredictionColumnSplit) {
|
||||||
|
RunWithInMemoryCommunicator(world_size_, TestCategoricalPrediction, true, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(GPUPredictor, CategoricalPredictLeaf) {
|
TEST(GPUPredictor, CategoricalPredictLeaf) {
|
||||||
auto ctx = MakeCUDACtx(0);
|
TestCategoricalPredictLeaf(true, false);
|
||||||
TestCategoricalPredictLeaf(&ctx, false);
|
}
|
||||||
|
|
||||||
|
TEST_F(MGPUPredictorTest, CategoricalPredictionLeafColumnSplit) {
|
||||||
|
RunWithInMemoryCommunicator(world_size_, TestCategoricalPredictLeaf, true, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(GPUPredictor, PredictLeafBasic) {
|
TEST(GPUPredictor, PredictLeafBasic) {
|
||||||
@ -305,4 +319,9 @@ TEST(GPUPredictor, Sparse) {
|
|||||||
TestSparsePrediction(&ctx, 0.2);
|
TestSparsePrediction(&ctx, 0.2);
|
||||||
TestSparsePrediction(&ctx, 0.8);
|
TestSparsePrediction(&ctx, 0.8);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(MGPUPredictorTest, SparseColumnSplit) {
|
||||||
|
TestSparsePredictionColumnSplit(world_size_, true, 0.2);
|
||||||
|
TestSparsePredictionColumnSplit(world_size_, true, 0.8);
|
||||||
|
}
|
||||||
} // namespace xgboost::predictor
|
} // namespace xgboost::predictor
|
||||||
|
|||||||
@ -172,16 +172,6 @@ void VerifyPredictionWithLesserFeatures(Learner *learner, bst_row_t kRows,
|
|||||||
ASSERT_THROW({ learner->Predict(m_invalid, false, &prediction, 0, 0); }, dmlc::Error);
|
ASSERT_THROW({ learner->Predict(m_invalid, false, &prediction, 0, 0); }, dmlc::Error);
|
||||||
}
|
}
|
||||||
|
|
||||||
void VerifyPredictionWithLesserFeaturesColumnSplit(Learner *learner, size_t rows,
|
|
||||||
std::shared_ptr<DMatrix> m_test,
|
|
||||||
std::shared_ptr<DMatrix> m_invalid) {
|
|
||||||
auto const world_size = collective::GetWorldSize();
|
|
||||||
auto const rank = collective::GetRank();
|
|
||||||
std::shared_ptr<DMatrix> sliced_test{m_test->SliceCol(world_size, rank)};
|
|
||||||
std::shared_ptr<DMatrix> sliced_invalid{m_invalid->SliceCol(world_size, rank)};
|
|
||||||
|
|
||||||
VerifyPredictionWithLesserFeatures(learner, rows, sliced_test, sliced_invalid);
|
|
||||||
}
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
void TestPredictionWithLesserFeatures(Context const *ctx) {
|
void TestPredictionWithLesserFeatures(Context const *ctx) {
|
||||||
@ -229,16 +219,24 @@ void TestPredictionDeviceAccess() {
|
|||||||
#endif // defined(XGBOOST_USE_CUDA)
|
#endif // defined(XGBOOST_USE_CUDA)
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestPredictionWithLesserFeaturesColumnSplit(Context const *ctx) {
|
void TestPredictionWithLesserFeaturesColumnSplit(bool use_gpu) {
|
||||||
size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
|
auto const world_size = collective::GetWorldSize();
|
||||||
auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).GenerateDMatrix(true);
|
auto const rank = collective::GetRank();
|
||||||
auto learner = LearnerForTest(ctx, m_train, kIters);
|
|
||||||
|
std::size_t constexpr kRows = 256, kTrainCols = 256, kTestCols = 4, kIters = 4;
|
||||||
|
auto m_train = RandomDataGenerator(kRows, kTrainCols, 0.5).Seed(rank).GenerateDMatrix(true);
|
||||||
|
Context ctx;
|
||||||
|
if (use_gpu) {
|
||||||
|
ctx = MakeCUDACtx(common::AllVisibleGPUs() == 1 ? 0 : rank);
|
||||||
|
}
|
||||||
|
auto learner = LearnerForTest(&ctx, m_train, kIters);
|
||||||
auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
|
auto m_test = RandomDataGenerator(kRows, kTestCols, 0.5).GenerateDMatrix(false);
|
||||||
auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);
|
auto m_invalid = RandomDataGenerator(kRows, kTrainCols + 1, 0.5).GenerateDMatrix(false);
|
||||||
|
|
||||||
auto constexpr kWorldSize = 2;
|
std::shared_ptr<DMatrix> sliced_test{m_test->SliceCol(world_size, rank)};
|
||||||
RunWithInMemoryCommunicator(kWorldSize, VerifyPredictionWithLesserFeaturesColumnSplit,
|
std::shared_ptr<DMatrix> sliced_invalid{m_invalid->SliceCol(world_size, rank)};
|
||||||
learner.get(), kRows, m_test, m_invalid);
|
|
||||||
|
VerifyPredictionWithLesserFeatures(learner.get(), kRows, sliced_test, sliced_invalid);
|
||||||
}
|
}
|
||||||
|
|
||||||
void GBTreeModelForTest(gbm::GBTreeModel *model, uint32_t split_ind,
|
void GBTreeModelForTest(gbm::GBTreeModel *model, uint32_t split_ind,
|
||||||
@ -260,7 +258,11 @@ void GBTreeModelForTest(gbm::GBTreeModel *model, uint32_t split_ind,
|
|||||||
model->CommitModelGroup(std::move(trees), 0);
|
model->CommitModelGroup(std::move(trees), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestCategoricalPrediction(Context const* ctx, bool is_column_split) {
|
void TestCategoricalPrediction(bool use_gpu, bool is_column_split) {
|
||||||
|
Context ctx;
|
||||||
|
if (use_gpu) {
|
||||||
|
ctx = MakeCUDACtx(common::AllVisibleGPUs() == 1 ? 0 : collective::GetRank());
|
||||||
|
}
|
||||||
size_t constexpr kCols = 10;
|
size_t constexpr kCols = 10;
|
||||||
PredictionCacheEntry out_predictions;
|
PredictionCacheEntry out_predictions;
|
||||||
|
|
||||||
@ -270,10 +272,10 @@ void TestCategoricalPrediction(Context const* ctx, bool is_column_split) {
|
|||||||
float left_weight = 1.3f;
|
float left_weight = 1.3f;
|
||||||
float right_weight = 1.7f;
|
float right_weight = 1.7f;
|
||||||
|
|
||||||
gbm::GBTreeModel model(&mparam, ctx);
|
gbm::GBTreeModel model(&mparam, &ctx);
|
||||||
GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
|
GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
|
||||||
|
|
||||||
std::unique_ptr<Predictor> predictor{CreatePredictorForTest(ctx)};
|
std::unique_ptr<Predictor> predictor{CreatePredictorForTest(&ctx)};
|
||||||
|
|
||||||
std::vector<float> row(kCols);
|
std::vector<float> row(kCols);
|
||||||
row[split_ind] = split_cat;
|
row[split_ind] = split_cat;
|
||||||
@ -303,12 +305,11 @@ void TestCategoricalPrediction(Context const* ctx, bool is_column_split) {
|
|||||||
ASSERT_EQ(out_predictions.predictions.HostVector()[0], left_weight + score);
|
ASSERT_EQ(out_predictions.predictions.HostVector()[0], left_weight + score);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestCategoricalPredictionColumnSplit(Context const *ctx) {
|
void TestCategoricalPredictLeaf(bool use_gpu, bool is_column_split) {
|
||||||
auto constexpr kWorldSize = 2;
|
Context ctx;
|
||||||
RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPrediction, ctx, true);
|
if (use_gpu) {
|
||||||
|
ctx = MakeCUDACtx(common::AllVisibleGPUs() == 1 ? 0 : collective::GetRank());
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestCategoricalPredictLeaf(Context const *ctx, bool is_column_split) {
|
|
||||||
size_t constexpr kCols = 10;
|
size_t constexpr kCols = 10;
|
||||||
PredictionCacheEntry out_predictions;
|
PredictionCacheEntry out_predictions;
|
||||||
|
|
||||||
@ -319,10 +320,10 @@ void TestCategoricalPredictLeaf(Context const *ctx, bool is_column_split) {
|
|||||||
float left_weight = 1.3f;
|
float left_weight = 1.3f;
|
||||||
float right_weight = 1.7f;
|
float right_weight = 1.7f;
|
||||||
|
|
||||||
gbm::GBTreeModel model(&mparam, ctx);
|
gbm::GBTreeModel model(&mparam, &ctx);
|
||||||
GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
|
GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);
|
||||||
|
|
||||||
std::unique_ptr<Predictor> predictor{CreatePredictorForTest(ctx)};
|
std::unique_ptr<Predictor> predictor{CreatePredictorForTest(&ctx)};
|
||||||
|
|
||||||
std::vector<float> row(kCols);
|
std::vector<float> row(kCols);
|
||||||
row[split_ind] = split_cat;
|
row[split_ind] = split_cat;
|
||||||
@ -347,11 +348,6 @@ void TestCategoricalPredictLeaf(Context const *ctx, bool is_column_split) {
|
|||||||
ASSERT_EQ(out_predictions.predictions.HostVector()[0], 1);
|
ASSERT_EQ(out_predictions.predictions.HostVector()[0], 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestCategoricalPredictLeafColumnSplit(Context const *ctx) {
|
|
||||||
auto constexpr kWorldSize = 2;
|
|
||||||
RunWithInMemoryCommunicator(kWorldSize, TestCategoricalPredictLeaf, ctx, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
void TestIterationRange(Context const* ctx) {
|
void TestIterationRange(Context const* ctx) {
|
||||||
size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10;
|
size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10;
|
||||||
auto dmat = RandomDataGenerator(kRows, kCols, 0)
|
auto dmat = RandomDataGenerator(kRows, kCols, 0)
|
||||||
@ -411,15 +407,30 @@ void TestIterationRange(Context const* ctx) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
void VerifyIterationRangeColumnSplit(DMatrix *dmat, Learner *learner, Learner *sliced,
|
void VerifyIterationRangeColumnSplit(bool use_gpu, Json const &ranged_model,
|
||||||
|
Json const &sliced_model, std::size_t rows, std::size_t cols,
|
||||||
|
std::size_t classes,
|
||||||
std::vector<float> const &expected_margin_ranged,
|
std::vector<float> const &expected_margin_ranged,
|
||||||
std::vector<float> const &expected_margin_sliced,
|
std::vector<float> const &expected_margin_sliced,
|
||||||
std::vector<float> const &expected_leaf_ranged,
|
std::vector<float> const &expected_leaf_ranged,
|
||||||
std::vector<float> const &expected_leaf_sliced) {
|
std::vector<float> const &expected_leaf_sliced) {
|
||||||
auto const world_size = collective::GetWorldSize();
|
auto const world_size = collective::GetWorldSize();
|
||||||
auto const rank = collective::GetRank();
|
auto const rank = collective::GetRank();
|
||||||
|
Context ctx;
|
||||||
|
if (use_gpu) {
|
||||||
|
ctx = MakeCUDACtx(common::AllVisibleGPUs() == 1 ? 0 : rank);
|
||||||
|
}
|
||||||
|
auto dmat = RandomDataGenerator(rows, cols, 0).GenerateDMatrix(true, true, classes);
|
||||||
std::shared_ptr<DMatrix> Xy{dmat->SliceCol(world_size, rank)};
|
std::shared_ptr<DMatrix> Xy{dmat->SliceCol(world_size, rank)};
|
||||||
|
|
||||||
|
std::unique_ptr<Learner> learner{Learner::Create({Xy})};
|
||||||
|
learner->SetParam("device", ctx.DeviceName());
|
||||||
|
learner->LoadModel(ranged_model);
|
||||||
|
|
||||||
|
std::unique_ptr<Learner> sliced{Learner::Create({Xy})};
|
||||||
|
sliced->SetParam("device", ctx.DeviceName());
|
||||||
|
sliced->LoadModel(sliced_model);
|
||||||
|
|
||||||
HostDeviceVector<float> out_predt_sliced;
|
HostDeviceVector<float> out_predt_sliced;
|
||||||
HostDeviceVector<float> out_predt_ranged;
|
HostDeviceVector<float> out_predt_ranged;
|
||||||
|
|
||||||
@ -428,11 +439,15 @@ void VerifyIterationRangeColumnSplit(DMatrix *dmat, Learner *learner, Learner *s
|
|||||||
sliced->Predict(Xy, true, &out_predt_sliced, 0, 0, false, false, false, false, false);
|
sliced->Predict(Xy, true, &out_predt_sliced, 0, 0, false, false, false, false, false);
|
||||||
learner->Predict(Xy, true, &out_predt_ranged, 0, 3, false, false, false, false, false);
|
learner->Predict(Xy, true, &out_predt_ranged, 0, 3, false, false, false, false, false);
|
||||||
auto const &h_sliced = out_predt_sliced.HostVector();
|
auto const &h_sliced = out_predt_sliced.HostVector();
|
||||||
auto const &h_range = out_predt_ranged.HostVector();
|
auto const &h_ranged = out_predt_ranged.HostVector();
|
||||||
ASSERT_EQ(h_sliced.size(), expected_margin_sliced.size());
|
EXPECT_EQ(h_sliced.size(), expected_margin_sliced.size());
|
||||||
ASSERT_EQ(h_sliced, expected_margin_sliced);
|
for (std::size_t i = 0; i < expected_margin_sliced.size(); ++i) {
|
||||||
ASSERT_EQ(h_range.size(), expected_margin_ranged.size());
|
ASSERT_FLOAT_EQ(h_sliced[i], expected_margin_sliced[i]) << "rank " << rank << ", i " << i;
|
||||||
ASSERT_EQ(h_range, expected_margin_ranged);
|
}
|
||||||
|
EXPECT_EQ(h_ranged.size(), expected_margin_ranged.size());
|
||||||
|
for (std::size_t i = 0; i < expected_margin_ranged.size(); ++i) {
|
||||||
|
ASSERT_FLOAT_EQ(h_ranged[i], expected_margin_ranged[i]) << "rank " << rank << ", i " << i;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Leaf
|
// Leaf
|
||||||
@ -440,21 +455,27 @@ void VerifyIterationRangeColumnSplit(DMatrix *dmat, Learner *learner, Learner *s
|
|||||||
sliced->Predict(Xy, false, &out_predt_sliced, 0, 0, false, true, false, false, false);
|
sliced->Predict(Xy, false, &out_predt_sliced, 0, 0, false, true, false, false, false);
|
||||||
learner->Predict(Xy, false, &out_predt_ranged, 0, 3, false, true, false, false, false);
|
learner->Predict(Xy, false, &out_predt_ranged, 0, 3, false, true, false, false, false);
|
||||||
auto const &h_sliced = out_predt_sliced.HostVector();
|
auto const &h_sliced = out_predt_sliced.HostVector();
|
||||||
auto const &h_range = out_predt_ranged.HostVector();
|
auto const &h_ranged = out_predt_ranged.HostVector();
|
||||||
ASSERT_EQ(h_sliced.size(), expected_leaf_sliced.size());
|
EXPECT_EQ(h_sliced.size(), expected_leaf_sliced.size());
|
||||||
ASSERT_EQ(h_sliced, expected_leaf_sliced);
|
for (std::size_t i = 0; i < expected_leaf_sliced.size(); ++i) {
|
||||||
ASSERT_EQ(h_range.size(), expected_leaf_ranged.size());
|
ASSERT_FLOAT_EQ(h_sliced[i], expected_leaf_sliced[i]) << "rank " << rank << ", i " << i;
|
||||||
ASSERT_EQ(h_range, expected_leaf_ranged);
|
}
|
||||||
|
EXPECT_EQ(h_ranged.size(), expected_leaf_ranged.size());
|
||||||
|
for (std::size_t i = 0; i < expected_leaf_ranged.size(); ++i) {
|
||||||
|
ASSERT_FLOAT_EQ(h_ranged[i], expected_leaf_ranged[i]) << "rank " << rank << ", i " << i;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
void TestIterationRangeColumnSplit(Context const* ctx) {
|
void TestIterationRangeColumnSplit(int world_size, bool use_gpu) {
|
||||||
size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10;
|
std::size_t constexpr kRows = 1000, kCols = 20, kClasses = 4, kForest = 3, kIters = 10;
|
||||||
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
|
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
|
||||||
auto learner = LearnerForTest(ctx, dmat, kIters, kForest);
|
Context ctx;
|
||||||
|
if (use_gpu) {
|
||||||
learner->SetParam("device", ctx->DeviceName());
|
ctx = MakeCUDACtx(0);
|
||||||
|
}
|
||||||
|
auto learner = LearnerForTest(&ctx, dmat, kIters, kForest);
|
||||||
|
|
||||||
bool bound = false;
|
bool bound = false;
|
||||||
std::unique_ptr<Learner> sliced{learner->Slice(0, 3, 1, &bound)};
|
std::unique_ptr<Learner> sliced{learner->Slice(0, 3, 1, &bound)};
|
||||||
@ -476,9 +497,13 @@ void TestIterationRangeColumnSplit(Context const* ctx) {
|
|||||||
auto const &leaf_sliced = leaf_predt_sliced.HostVector();
|
auto const &leaf_sliced = leaf_predt_sliced.HostVector();
|
||||||
auto const &leaf_ranged = leaf_predt_ranged.HostVector();
|
auto const &leaf_ranged = leaf_predt_ranged.HostVector();
|
||||||
|
|
||||||
auto constexpr kWorldSize = 2;
|
Json ranged_model{Object{}};
|
||||||
RunWithInMemoryCommunicator(kWorldSize, VerifyIterationRangeColumnSplit, dmat.get(),
|
learner->SaveModel(&ranged_model);
|
||||||
learner.get(), sliced.get(), margin_ranged, margin_sliced,
|
Json sliced_model{Object{}};
|
||||||
|
sliced->SaveModel(&sliced_model);
|
||||||
|
|
||||||
|
RunWithInMemoryCommunicator(world_size, VerifyIterationRangeColumnSplit, use_gpu, ranged_model,
|
||||||
|
sliced_model, kRows, kCols, kClasses, margin_ranged, margin_sliced,
|
||||||
leaf_ranged, leaf_sliced);
|
leaf_ranged, leaf_sliced);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -539,11 +564,20 @@ void TestSparsePrediction(Context const *ctx, float sparsity) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
void VerifySparsePredictionColumnSplit(DMatrix *dmat, Learner *learner,
|
void VerifySparsePredictionColumnSplit(bool use_gpu, Json const &model, std::size_t rows,
|
||||||
|
std::size_t cols, float sparsity,
|
||||||
std::vector<float> const &expected_predt) {
|
std::vector<float> const &expected_predt) {
|
||||||
std::shared_ptr<DMatrix> sliced{
|
Context ctx;
|
||||||
dmat->SliceCol(collective::GetWorldSize(), collective::GetRank())};
|
if (use_gpu) {
|
||||||
|
ctx = MakeCUDACtx(common::AllVisibleGPUs() == 1 ? 0 : collective::GetRank());
|
||||||
|
}
|
||||||
|
auto Xy = RandomDataGenerator(rows, cols, sparsity).GenerateDMatrix(true);
|
||||||
|
std::shared_ptr<DMatrix> sliced{Xy->SliceCol(collective::GetWorldSize(), collective::GetRank())};
|
||||||
HostDeviceVector<float> sparse_predt;
|
HostDeviceVector<float> sparse_predt;
|
||||||
|
|
||||||
|
std::unique_ptr<Learner> learner{Learner::Create({sliced})};
|
||||||
|
learner->SetParam("device", ctx.DeviceName());
|
||||||
|
learner->LoadModel(model);
|
||||||
learner->Predict(sliced, false, &sparse_predt, 0, 0);
|
learner->Predict(sliced, false, &sparse_predt, 0, 0);
|
||||||
|
|
||||||
auto const &predt = sparse_predt.HostVector();
|
auto const &predt = sparse_predt.HostVector();
|
||||||
@ -554,10 +588,14 @@ void VerifySparsePredictionColumnSplit(DMatrix *dmat, Learner *learner,
|
|||||||
}
|
}
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
void TestSparsePredictionColumnSplit(Context const* ctx, float sparsity) {
|
void TestSparsePredictionColumnSplit(int world_size, bool use_gpu, float sparsity) {
|
||||||
|
Context ctx;
|
||||||
|
if (use_gpu) {
|
||||||
|
ctx = MakeCUDACtx(0);
|
||||||
|
}
|
||||||
size_t constexpr kRows = 512, kCols = 128, kIters = 4;
|
size_t constexpr kRows = 512, kCols = 128, kIters = 4;
|
||||||
auto Xy = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix(true);
|
auto Xy = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix(true);
|
||||||
auto learner = LearnerForTest(ctx, Xy, kIters);
|
auto learner = LearnerForTest(&ctx, Xy, kIters);
|
||||||
|
|
||||||
HostDeviceVector<float> sparse_predt;
|
HostDeviceVector<float> sparse_predt;
|
||||||
|
|
||||||
@ -567,12 +605,11 @@ void TestSparsePredictionColumnSplit(Context const* ctx, float sparsity) {
|
|||||||
learner.reset(Learner::Create({Xy}));
|
learner.reset(Learner::Create({Xy}));
|
||||||
learner->LoadModel(model);
|
learner->LoadModel(model);
|
||||||
|
|
||||||
learner->SetParam("device", ctx->DeviceName());
|
learner->SetParam("device", ctx.DeviceName());
|
||||||
learner->Predict(Xy, false, &sparse_predt, 0, 0);
|
learner->Predict(Xy, false, &sparse_predt, 0, 0);
|
||||||
|
|
||||||
auto constexpr kWorldSize = 2;
|
RunWithInMemoryCommunicator(world_size, VerifySparsePredictionColumnSplit, use_gpu, model,
|
||||||
RunWithInMemoryCommunicator(kWorldSize, VerifySparsePredictionColumnSplit, Xy.get(),
|
kRows, kCols, sparsity, sparse_predt.HostVector());
|
||||||
learner.get(), sparse_predt.HostVector());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestVectorLeafPrediction(Context const *ctx) {
|
void TestVectorLeafPrediction(Context const *ctx) {
|
||||||
|
|||||||
@ -94,23 +94,19 @@ void TestPredictionWithLesserFeatures(Context const* ctx);
|
|||||||
|
|
||||||
void TestPredictionDeviceAccess();
|
void TestPredictionDeviceAccess();
|
||||||
|
|
||||||
void TestCategoricalPrediction(Context const* ctx, bool is_column_split);
|
void TestCategoricalPrediction(bool use_gpu, bool is_column_split);
|
||||||
|
|
||||||
void TestCategoricalPredictionColumnSplit(Context const* ctx);
|
void TestPredictionWithLesserFeaturesColumnSplit(bool use_gpu);
|
||||||
|
|
||||||
void TestPredictionWithLesserFeaturesColumnSplit(Context const* ctx);
|
void TestCategoricalPredictLeaf(bool use_gpu, bool is_column_split);
|
||||||
|
|
||||||
void TestCategoricalPredictLeaf(Context const* ctx, bool is_column_split);
|
|
||||||
|
|
||||||
void TestCategoricalPredictLeafColumnSplit(Context const* ctx);
|
|
||||||
|
|
||||||
void TestIterationRange(Context const* ctx);
|
void TestIterationRange(Context const* ctx);
|
||||||
|
|
||||||
void TestIterationRangeColumnSplit(Context const* ctx);
|
void TestIterationRangeColumnSplit(int world_size, bool use_gpu);
|
||||||
|
|
||||||
void TestSparsePrediction(Context const* ctx, float sparsity);
|
void TestSparsePrediction(Context const* ctx, float sparsity);
|
||||||
|
|
||||||
void TestSparsePredictionColumnSplit(Context const* ctx, float sparsity);
|
void TestSparsePredictionColumnSplit(int world_size, bool use_gpu, float sparsity);
|
||||||
|
|
||||||
void TestVectorLeafPrediction(Context const* ctx);
|
void TestVectorLeafPrediction(Context const* ctx);
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user