Support column split in gpu hist updater (#9384)
This commit is contained in:
@@ -30,7 +30,7 @@ void TestUpdatePositionBatch() {
|
||||
std::vector<int> extra_data = {0};
|
||||
// Send the first five training instances to the right node
|
||||
// and the second 5 to the left node
|
||||
rp.UpdatePositionBatch({0}, {1}, {2}, extra_data, [=] __device__(RowPartitioner::RowIndexT ridx, int) {
|
||||
rp.UpdatePositionBatch({0}, {1}, {2}, extra_data, [=] __device__(RowPartitioner::RowIndexT ridx, int, int) {
|
||||
return ridx > 4;
|
||||
});
|
||||
rows = rp.GetRowsHost(1);
|
||||
@@ -43,7 +43,7 @@ void TestUpdatePositionBatch() {
|
||||
}
|
||||
|
||||
// Split the left node again
|
||||
rp.UpdatePositionBatch({1}, {3}, {4}, extra_data,[=] __device__(RowPartitioner::RowIndexT ridx, int) {
|
||||
rp.UpdatePositionBatch({1}, {3}, {4}, extra_data,[=] __device__(RowPartitioner::RowIndexT ridx, int, int) {
|
||||
return ridx < 7;
|
||||
});
|
||||
EXPECT_EQ(rp.GetRows(3).size(), 2);
|
||||
@@ -57,7 +57,7 @@ void TestSortPositionBatch(const std::vector<int>& ridx_in, const std::vector<Se
|
||||
thrust::device_vector<uint32_t> ridx_tmp(ridx_in.size());
|
||||
thrust::device_vector<bst_uint> counts(segments.size());
|
||||
|
||||
auto op = [=] __device__(auto ridx, int data) { return ridx % 2 == 0; };
|
||||
auto op = [=] __device__(auto ridx, int split_index, int data) { return ridx % 2 == 0; };
|
||||
std::vector<int> op_data(segments.size());
|
||||
std::vector<PerNodeData<int>> h_batch_info(segments.size());
|
||||
dh::TemporaryArray<PerNodeData<int>> d_batch_info(segments.size());
|
||||
|
||||
Reference in New Issue
Block a user