Support column split in GPU predictor (#9343)

This commit is contained in:
Rong Ou
2023-07-02 13:05:34 -07:00
committed by GitHub
parent f90771eec6
commit 3a0f787703
5 changed files with 288 additions and 25 deletions

View File

@@ -122,10 +122,11 @@ template <typename Func>
void RunBitwiseAllreduce(char *out_buffer, char const *device_buffer, Func func, int world_size,
std::size_t size, cudaStream_t stream) {
dh::LaunchN(size, stream, [=] __device__(std::size_t idx) {
out_buffer[idx] = device_buffer[idx];
auto result = device_buffer[idx];
for (auto rank = 1; rank < world_size; rank++) {
out_buffer[idx] = func(out_buffer[idx], device_buffer[rank * size + idx]);
result = func(result, device_buffer[rank * size + idx]);
}
out_buffer[idx] = result;
});
}
} // anonymous namespace