Support column split in multi-target hist (#9171)
This commit is contained in:
@@ -194,11 +194,65 @@ void TestColumnSplitPartitioner(bst_target_t n_targets) {
|
||||
|
||||
auto constexpr kWorkers = 4;
|
||||
RunWithInMemoryCommunicator(kWorkers, VerifyColumnSplitPartitioner<ExpandEntry>, n_targets,
|
||||
n_samples, n_features, base_rowid, Xy, min_value, mid_value, mid_partitioner);
|
||||
n_samples, n_features, base_rowid, Xy, min_value, mid_value,
|
||||
mid_partitioner);
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(QuantileHist, PartitionerColSplit) { TestColumnSplitPartitioner<CPUExpandEntry>(1); }
|
||||
|
||||
TEST(QuantileHist, MultiPartitionerColSplit) { TestColumnSplitPartitioner<MultiExpandEntry>(3); }
|
||||
|
||||
namespace {
|
||||
void VerifyColumnSplit(bst_row_t rows, bst_feature_t cols, bst_target_t n_targets,
|
||||
RegTree const& expected_tree) {
|
||||
auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true);
|
||||
auto p_gradients = GenerateGradients(rows, n_targets);
|
||||
Context ctx;
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_quantile_histmaker", &ctx, &task)};
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
|
||||
std::unique_ptr<DMatrix> sliced{Xy->SliceCol(collective::GetWorldSize(), collective::GetRank())};
|
||||
|
||||
RegTree tree{n_targets, cols};
|
||||
TrainParam param;
|
||||
param.Init(Args{});
|
||||
updater->Update(¶m, p_gradients.get(), sliced.get(), position, {&tree});
|
||||
|
||||
Json json{Object{}};
|
||||
tree.SaveModel(&json);
|
||||
Json expected_json{Object{}};
|
||||
expected_tree.SaveModel(&expected_json);
|
||||
ASSERT_EQ(json, expected_json);
|
||||
}
|
||||
|
||||
void TestColumnSplit(bst_target_t n_targets) {
|
||||
auto constexpr kRows = 32;
|
||||
auto constexpr kCols = 16;
|
||||
|
||||
RegTree expected_tree{n_targets, kCols};
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
{
|
||||
auto Xy = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
|
||||
auto p_gradients = GenerateGradients(kRows, n_targets);
|
||||
Context ctx;
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_quantile_histmaker", &ctx, &task)};
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
TrainParam param;
|
||||
param.Init(Args{});
|
||||
updater->Update(¶m, p_gradients.get(), Xy.get(), position, {&expected_tree});
|
||||
}
|
||||
|
||||
auto constexpr kWorldSize = 2;
|
||||
RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit, kRows, kCols, n_targets,
|
||||
std::cref(expected_tree));
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(QuantileHist, ColumnSplit) { TestColumnSplit(1); }
|
||||
|
||||
TEST(QuantileHist, ColumnSplitMultiTarget) { TestColumnSplit(3); }
|
||||
|
||||
} // namespace xgboost::tree
|
||||
|
||||
Reference in New Issue
Block a user