/**
 * Copyright 2024, XGBoost Contributors
 */
#include "test_column_split.h"

#include <gtest/gtest.h>
#include <xgboost/tree_model.h>    // for RegTree
#include <xgboost/tree_updater.h>  // for TreeUpdater

#include <memory>  // for unique_ptr
#include <string>  // for string, to_string
#include <thread>  // for hardware_concurrency
#include <vector>  // for vector

#include "../../../src/tree/param.h"    // for TrainParam
#include "../collective/test_worker.h"  // for TestDistributedGlobal

namespace xgboost::tree {
/**
 * @brief Check that growing a tree on column-split data reproduces the tree grown on the full data.
 *
 * A reference tree is grown on the complete matrix and serialized to JSON. Then, inside each
 * worker started by collective::TestDistributedGlobal, the same matrix and gradients are
 * regenerated, the matrix is sliced by column according to the worker's rank, and a second tree is
 * grown on the slice. Its JSON must equal the reference JSON.
 *
 * @param n_targets   Number of targets used for the trees and the generated gradients.
 * @param categorical Forwarded to GenerateCatDMatrix.
 * @param name        Updater name passed to TreeUpdater::Create.
 * @param sparsity    Forwarded to GenerateCatDMatrix.
 */
void TestColumnSplit(bst_target_t n_targets, bool categorical, std::string name, float sparsity) {
  auto constexpr kRows = 32;
  auto constexpr kCols = 16;
  auto constexpr kWorldSize = 2;

  // Grows one tree with the named updater and default training parameters. Both the reference
  // run and the column-split run go through this helper so that they are configured identically;
  // otherwise the JSON comparison below would not be meaningful.
  auto update_tree = [&name](Context const* ctx, auto* gpair, DMatrix* p_fmat, RegTree* p_tree) {
    ObjInfo task{ObjInfo::kRegression};
    std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create(name, ctx, &task)};
    std::vector<HostDeviceVector<bst_node_t>> position(1);
    TrainParam param;
    param.Init(Args{});
    updater->Configure(Args{});
    updater->Update(&param, gpair, p_fmat, position, {p_tree});
  };

  // Reference tree grown on all columns. It is serialized once, before any worker starts, so the
  // workers only read `expected_json`.
  Json expected_json{Object{}};
  {
    Context ctx;
    auto p_dmat = GenerateCatDMatrix(kRows, kCols, sparsity, categorical);
    auto gpair = GenerateRandomGradients(&ctx, kRows, n_targets);
    RegTree expected_tree{n_targets, static_cast<bst_feature_t>(kCols)};
    update_tree(&ctx, &gpair, p_dmat.get(), &expected_tree);
    expected_tree.SaveModel(&expected_json);
  }

  auto verify = [&] {
    Context ctx;
    ctx.UpdateAllowUnknown(
        Args{{"nthread", std::to_string(collective::GetWorkerLocalThreads(kWorldSize))}});

    // Regenerate the same inputs as the reference run (same call order), then keep only this
    // worker's share of the columns.
    auto p_dmat = GenerateCatDMatrix(kRows, kCols, sparsity, categorical);
    auto gpair = GenerateRandomGradients(&ctx, kRows, n_targets);
    std::unique_ptr<DMatrix> sliced{
        p_dmat->SliceCol(collective::GetWorldSize(), collective::GetRank())};

    RegTree tree{n_targets, static_cast<bst_feature_t>(kCols)};
    update_tree(&ctx, &gpair, sliced.get(), &tree);

    Json json{Object{}};
    tree.SaveModel(&json);
    ASSERT_EQ(json, expected_json);
  };

  collective::TestDistributedGlobal(kWorldSize, [&] { verify(); });
}
}  // namespace xgboost::tree