Add tests for gpu_approx (#9553)
This commit is contained in:
parent
6c791b5b47
commit
66a0832778
@ -120,6 +120,11 @@ TEST_P(VerticalFederatedLearnerTest, Hist) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
|
TEST_P(VerticalFederatedLearnerTest, GPUApprox) {
|
||||||
|
std::string objective = GetParam();
|
||||||
|
this->Run("approx", "cuda:0", objective);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_P(VerticalFederatedLearnerTest, GPUHist) {
|
TEST_P(VerticalFederatedLearnerTest, GPUHist) {
|
||||||
std::string objective = GetParam();
|
std::string objective = GetParam();
|
||||||
this->Run("hist", "cuda:0", objective);
|
this->Run("hist", "cuda:0", objective);
|
||||||
|
|||||||
@ -428,7 +428,7 @@ TEST(GpuHist, MaxDepth) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
RegTree GetUpdatedTree(Context const* ctx, DMatrix* dmat) {
|
RegTree GetHistTree(Context const* ctx, DMatrix* dmat) {
|
||||||
ObjInfo task{ObjInfo::kRegression};
|
ObjInfo task{ObjInfo::kRegression};
|
||||||
GPUHistMaker hist_maker{ctx, &task};
|
GPUHistMaker hist_maker{ctx, &task};
|
||||||
hist_maker.Configure(Args{});
|
hist_maker.Configure(Args{});
|
||||||
@ -446,7 +446,7 @@ RegTree GetUpdatedTree(Context const* ctx, DMatrix* dmat) {
|
|||||||
return tree;
|
return tree;
|
||||||
}
|
}
|
||||||
|
|
||||||
void VerifyColumnSplit(bst_row_t rows, bst_feature_t cols, RegTree const& expected_tree) {
|
void VerifyHistColumnSplit(bst_row_t rows, bst_feature_t cols, RegTree const& expected_tree) {
|
||||||
Context ctx(MakeCUDACtx(GPUIDX));
|
Context ctx(MakeCUDACtx(GPUIDX));
|
||||||
|
|
||||||
auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true);
|
auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true);
|
||||||
@ -454,7 +454,7 @@ void VerifyColumnSplit(bst_row_t rows, bst_feature_t cols, RegTree const& expect
|
|||||||
auto const rank = collective::GetRank();
|
auto const rank = collective::GetRank();
|
||||||
std::unique_ptr<DMatrix> sliced{Xy->SliceCol(world_size, rank)};
|
std::unique_ptr<DMatrix> sliced{Xy->SliceCol(world_size, rank)};
|
||||||
|
|
||||||
RegTree tree = GetUpdatedTree(&ctx, sliced.get());
|
RegTree tree = GetHistTree(&ctx, sliced.get());
|
||||||
|
|
||||||
Json json{Object{}};
|
Json json{Object{}};
|
||||||
tree.SaveModel(&json);
|
tree.SaveModel(&json);
|
||||||
@ -472,8 +472,58 @@ TEST_F(MGPUHistTest, GPUHistColumnSplit) {
|
|||||||
|
|
||||||
Context ctx(MakeCUDACtx(0));
|
Context ctx(MakeCUDACtx(0));
|
||||||
auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
|
auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
|
||||||
RegTree expected_tree = GetUpdatedTree(&ctx, dmat.get());
|
RegTree expected_tree = GetHistTree(&ctx, dmat.get());
|
||||||
|
|
||||||
DoTest(VerifyColumnSplit, kRows, kCols, expected_tree);
|
DoTest(VerifyHistColumnSplit, kRows, kCols, expected_tree);
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
RegTree GetApproxTree(Context const* ctx, DMatrix* dmat) {
|
||||||
|
ObjInfo task{ObjInfo::kRegression};
|
||||||
|
GPUGlobalApproxMaker approx_maker{ctx, &task};
|
||||||
|
approx_maker.Configure(Args{});
|
||||||
|
|
||||||
|
TrainParam param;
|
||||||
|
param.UpdateAllowUnknown(Args{});
|
||||||
|
|
||||||
|
linalg::Matrix<GradientPair> gpair({dmat->Info().num_row_}, ctx->Ordinal());
|
||||||
|
gpair.Data()->Copy(GenerateRandomGradients(dmat->Info().num_row_));
|
||||||
|
|
||||||
|
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||||
|
RegTree tree;
|
||||||
|
approx_maker.Update(¶m, &gpair, dmat, common::Span<HostDeviceVector<bst_node_t>>{position},
|
||||||
|
{&tree});
|
||||||
|
return tree;
|
||||||
|
}
|
||||||
|
|
||||||
|
void VerifyApproxColumnSplit(bst_row_t rows, bst_feature_t cols, RegTree const& expected_tree) {
|
||||||
|
Context ctx(MakeCUDACtx(GPUIDX));
|
||||||
|
|
||||||
|
auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true);
|
||||||
|
auto const world_size = collective::GetWorldSize();
|
||||||
|
auto const rank = collective::GetRank();
|
||||||
|
std::unique_ptr<DMatrix> sliced{Xy->SliceCol(world_size, rank)};
|
||||||
|
|
||||||
|
RegTree tree = GetApproxTree(&ctx, sliced.get());
|
||||||
|
|
||||||
|
Json json{Object{}};
|
||||||
|
tree.SaveModel(&json);
|
||||||
|
Json expected_json{Object{}};
|
||||||
|
expected_tree.SaveModel(&expected_json);
|
||||||
|
ASSERT_EQ(json, expected_json);
|
||||||
|
}
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
class MGPUApproxTest : public BaseMGPUTest {};
|
||||||
|
|
||||||
|
TEST_F(MGPUApproxTest, GPUApproxColumnSplit) {
|
||||||
|
auto constexpr kRows = 32;
|
||||||
|
auto constexpr kCols = 16;
|
||||||
|
|
||||||
|
Context ctx(MakeCUDACtx(0));
|
||||||
|
auto dmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
|
||||||
|
RegTree expected_tree = GetApproxTree(&ctx, dmat.get());
|
||||||
|
|
||||||
|
DoTest(VerifyApproxColumnSplit, kRows, kCols, expected_tree);
|
||||||
}
|
}
|
||||||
} // namespace xgboost::tree
|
} // namespace xgboost::tree
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user