Add tests for gpu_approx (#9553)
This commit is contained in:
parent
6c791b5b47
commit
66a0832778
@ -120,6 +120,11 @@ TEST_P(VerticalFederatedLearnerTest, Hist) {
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
TEST_P(VerticalFederatedLearnerTest, GPUApprox) {
|
||||
std::string objective = GetParam();
|
||||
this->Run("approx", "cuda:0", objective);
|
||||
}
|
||||
|
||||
TEST_P(VerticalFederatedLearnerTest, GPUHist) {
|
||||
std::string objective = GetParam();
|
||||
this->Run("hist", "cuda:0", objective);
|
||||
|
||||
@ -428,7 +428,7 @@ TEST(GpuHist, MaxDepth) {
|
||||
}
|
||||
|
||||
namespace {
|
||||
RegTree GetUpdatedTree(Context const* ctx, DMatrix* dmat) {
|
||||
RegTree GetHistTree(Context const* ctx, DMatrix* dmat) {
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
GPUHistMaker hist_maker{ctx, &task};
|
||||
hist_maker.Configure(Args{});
|
||||
@ -446,7 +446,7 @@ RegTree GetUpdatedTree(Context const* ctx, DMatrix* dmat) {
|
||||
return tree;
|
||||
}
|
||||
|
||||
void VerifyColumnSplit(bst_row_t rows, bst_feature_t cols, RegTree const& expected_tree) {
|
||||
void VerifyHistColumnSplit(bst_row_t rows, bst_feature_t cols, RegTree const& expected_tree) {
|
||||
Context ctx(MakeCUDACtx(GPUIDX));
|
||||
|
||||
auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true);
|
||||
@ -454,7 +454,7 @@ void VerifyColumnSplit(bst_row_t rows, bst_feature_t cols, RegTree const& expect
|
||||
auto const rank = collective::GetRank();
|
||||
std::unique_ptr<DMatrix> sliced{Xy->SliceCol(world_size, rank)};
|
||||
|
||||
RegTree tree = GetUpdatedTree(&ctx, sliced.get());
|
||||
RegTree tree = GetHistTree(&ctx, sliced.get());
|
||||
|
||||
Json json{Object{}};
|
||||
tree.SaveModel(&json);
|
||||
@ -472,8 +472,58 @@ TEST_F(MGPUHistTest, GPUHistColumnSplit) {
|
||||
|
||||
Context ctx(MakeCUDACtx(0));
|
||||
auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
|
||||
RegTree expected_tree = GetUpdatedTree(&ctx, dmat.get());
|
||||
RegTree expected_tree = GetHistTree(&ctx, dmat.get());
|
||||
|
||||
DoTest(VerifyColumnSplit, kRows, kCols, expected_tree);
|
||||
DoTest(VerifyHistColumnSplit, kRows, kCols, expected_tree);
|
||||
}
|
||||
|
||||
namespace {
|
||||
RegTree GetApproxTree(Context const* ctx, DMatrix* dmat) {
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
GPUGlobalApproxMaker approx_maker{ctx, &task};
|
||||
approx_maker.Configure(Args{});
|
||||
|
||||
TrainParam param;
|
||||
param.UpdateAllowUnknown(Args{});
|
||||
|
||||
linalg::Matrix<GradientPair> gpair({dmat->Info().num_row_}, ctx->Ordinal());
|
||||
gpair.Data()->Copy(GenerateRandomGradients(dmat->Info().num_row_));
|
||||
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
RegTree tree;
|
||||
approx_maker.Update(¶m, &gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
|
||||
{&tree});
|
||||
return tree;
|
||||
}
|
||||
|
||||
void VerifyApproxColumnSplit(bst_row_t rows, bst_feature_t cols, RegTree const& expected_tree) {
|
||||
Context ctx(MakeCUDACtx(GPUIDX));
|
||||
|
||||
auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true);
|
||||
auto const world_size = collective::GetWorldSize();
|
||||
auto const rank = collective::GetRank();
|
||||
std::unique_ptr<DMatrix> sliced{Xy->SliceCol(world_size, rank)};
|
||||
|
||||
RegTree tree = GetApproxTree(&ctx, sliced.get());
|
||||
|
||||
Json json{Object{}};
|
||||
tree.SaveModel(&json);
|
||||
Json expected_json{Object{}};
|
||||
expected_tree.SaveModel(&expected_json);
|
||||
ASSERT_EQ(json, expected_json);
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
class MGPUApproxTest : public BaseMGPUTest {};
|
||||
|
||||
TEST_F(MGPUApproxTest, GPUApproxColumnSplit) {
|
||||
auto constexpr kRows = 32;
|
||||
auto constexpr kCols = 16;
|
||||
|
||||
Context ctx(MakeCUDACtx(0));
|
||||
auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
|
||||
RegTree expected_tree = GetApproxTree(&ctx, dmat.get());
|
||||
|
||||
DoTest(VerifyApproxColumnSplit, kRows, kCols, expected_tree);
|
||||
}
|
||||
} // namespace xgboost::tree
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user