Initial support for column-split cpu predictor (#8676)
This commit is contained in:
@@ -4,6 +4,9 @@
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/predictor.h>
|
||||
|
||||
#include <thread>
|
||||
|
||||
#include "../../../src/collective/communicator-inl.h"
|
||||
#include "../../../src/data/adapter.h"
|
||||
#include "../../../src/data/proxy_dmatrix.h"
|
||||
#include "../../../src/gbm/gbtree.h"
|
||||
@@ -86,6 +89,49 @@ TEST(CpuPredictor, Basic) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CpuPredictor, ColumnSplit) {
|
||||
size_t constexpr kRows = 5;
|
||||
size_t constexpr kCols = 5;
|
||||
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
||||
|
||||
std::vector<std::thread> threads;
|
||||
size_t constexpr kWorldSize = 2;
|
||||
size_t constexpr kSliceSize = (kCols + 1) / kWorldSize;
|
||||
for (auto rank = 0; rank < kWorldSize; rank++) {
|
||||
threads.emplace_back([=, &dmat]() {
|
||||
Json config{JsonObject()};
|
||||
config["xgboost_communicator"] = String("in-memory");
|
||||
config["in_memory_world_size"] = kWorldSize;
|
||||
config["in_memory_rank"] = rank;
|
||||
xgboost::collective::Init(config);
|
||||
|
||||
auto lparam = CreateEmptyGenericParam(GPUIDX);
|
||||
std::unique_ptr<Predictor> cpu_predictor =
|
||||
std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor", &lparam));
|
||||
|
||||
LearnerModelParam mparam{MakeMP(kCols, .0, 1)};
|
||||
|
||||
Context ctx;
|
||||
ctx.UpdateAllowUnknown(Args{});
|
||||
gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx);
|
||||
|
||||
// Test predict batch
|
||||
PredictionCacheEntry out_predictions;
|
||||
cpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model);
|
||||
auto sliced = std::unique_ptr<DMatrix>{dmat->SliceCol(rank * kSliceSize, kSliceSize)};
|
||||
cpu_predictor->PredictBatch(sliced.get(), &out_predictions, model, 0);
|
||||
|
||||
std::vector<float>& out_predictions_h = out_predictions.predictions.HostVector();
|
||||
for (size_t i = 0; i < out_predictions.predictions.Size(); i++) {
|
||||
ASSERT_EQ(out_predictions_h[i], 1.5);
|
||||
}
|
||||
xgboost::collective::Finalize();
|
||||
});
|
||||
}
|
||||
for (auto& thread : threads) {
|
||||
thread.join();
|
||||
}
|
||||
}
|
||||
|
||||
TEST(CpuPredictor, IterationRange) {
|
||||
TestIterationRange("cpu_predictor");
|
||||
|
||||
Reference in New Issue
Block a user