Support column split in gpu hist updater (#9384)

This commit is contained in:
Rong Ou
2023-08-31 03:09:35 -07:00
committed by GitHub
parent ccfc90e4c6
commit 9bab06cbca
10 changed files with 187 additions and 28 deletions

View File

@@ -30,7 +30,7 @@ void TestUpdatePositionBatch() {
std::vector<int> extra_data = {0};
// Send the first five training instances to the right node
// and the second 5 to the left node
rp.UpdatePositionBatch({0}, {1}, {2}, extra_data, [=] __device__(RowPartitioner::RowIndexT ridx, int) {
rp.UpdatePositionBatch({0}, {1}, {2}, extra_data, [=] __device__(RowPartitioner::RowIndexT ridx, int, int) {
return ridx > 4;
});
rows = rp.GetRowsHost(1);
@@ -43,7 +43,7 @@ void TestUpdatePositionBatch() {
}
// Split the left node again
rp.UpdatePositionBatch({1}, {3}, {4}, extra_data,[=] __device__(RowPartitioner::RowIndexT ridx, int) {
rp.UpdatePositionBatch({1}, {3}, {4}, extra_data,[=] __device__(RowPartitioner::RowIndexT ridx, int, int) {
return ridx < 7;
});
EXPECT_EQ(rp.GetRows(3).size(), 2);
@@ -57,7 +57,7 @@ void TestSortPositionBatch(const std::vector<int>& ridx_in, const std::vector<Se
thrust::device_vector<uint32_t> ridx_tmp(ridx_in.size());
thrust::device_vector<bst_uint> counts(segments.size());
auto op = [=] __device__(auto ridx, int data) { return ridx % 2 == 0; };
auto op = [=] __device__(auto ridx, int split_index, int data) { return ridx % 2 == 0; };
std::vector<int> op_data(segments.size());
std::vector<PerNodeData<int>> h_batch_info(segments.size());
dh::TemporaryArray<PerNodeData<int>> d_batch_info(segments.size());