From 309268de0219be73f9db5d5a4d0d89e7e6987844 Mon Sep 17 00:00:00 2001 From: amdsc21 <96135754+amdsc21@users.noreply.github.com> Date: Thu, 9 Mar 2023 22:40:44 +0100 Subject: [PATCH] finish updater_gpu_hist.cu --- src/tree/constraints.cuh | 5 +++ src/tree/updater_gpu_hist.cu | 74 +++++++++++++++++++++++++++++++++++ src/tree/updater_gpu_hist.hip | 4 ++ 3 files changed, 83 insertions(+) diff --git a/src/tree/constraints.cuh b/src/tree/constraints.cuh index 94c262240..bb20c8cf8 100644 --- a/src/tree/constraints.cuh +++ b/src/tree/constraints.cuh @@ -15,7 +15,12 @@ #include "constraints.h" #include "xgboost/span.h" #include "../common/bitfield.h" + +#if defined(XGBOOST_USE_CUDA) #include "../common/device_helpers.cuh" +#elif defined(XGBOOST_USE_HIP) +#include "../common/device_helpers.hip.h" +#endif namespace xgboost { // Feature interaction constraints built for GPU Hist updater. diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 32b3f4a03..d721c40bf 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -15,7 +15,13 @@ #include "../collective/device_communicator.cuh" #include "../common/bitfield.h" #include "../common/categorical.h" + +#if defined(XGBOOST_USE_CUDA) #include "../common/device_helpers.cuh" +#elif defined(XGBOOST_USE_HIP) +#include "../common/device_helpers.hip.h" +#endif + #include "../common/hist_util.h" #include "../common/io.h" #include "../common/timer.h" @@ -235,7 +241,11 @@ struct GPUHistMakerDevice { } ~GPUHistMakerDevice() { // NOLINT +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(ctx_->gpu_id)); +#endif } // Reset values for each update iteration @@ -246,7 +256,11 @@ struct GPUHistMakerDevice { this->column_sampler.Init(ctx_, num_columns, info.feature_weights.HostVector(), param.colsample_bynode, param.colsample_bylevel, param.colsample_bytree); +#if defined(XGBOOST_USE_CUDA) 
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(ctx_->gpu_id)); +#endif this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param, ctx_->gpu_id); @@ -256,9 +270,17 @@ struct GPUHistMakerDevice { if (d_gpair.size() != dh_gpair->Size()) { d_gpair.resize(dh_gpair->Size()); } + +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaMemcpyAsync( d_gpair.data().get(), dh_gpair->ConstDevicePointer(), dh_gpair->Size() * sizeof(GradientPair), cudaMemcpyDeviceToDevice)); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipMemcpyAsync( + d_gpair.data().get(), dh_gpair->ConstDevicePointer(), + dh_gpair->Size() * sizeof(GradientPair), hipMemcpyDeviceToDevice)); +#endif + auto sample = sampler->Sample(dh::ToSpan(d_gpair), dmat); page = sample.page; gpair = sample.gpair; @@ -337,16 +359,30 @@ struct GPUHistMakerDevice { max_active_features = std::max(max_active_features, static_cast<bst_feature_t>(input.feature_set.size())); } +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaMemcpyAsync( d_node_inputs.data().get(), h_node_inputs.data(), h_node_inputs.size() * sizeof(EvaluateSplitInputs), cudaMemcpyDefault)); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipMemcpyAsync( + d_node_inputs.data().get(), h_node_inputs.data(), + h_node_inputs.size() * sizeof(EvaluateSplitInputs), hipMemcpyDefault)); +#endif this->evaluator_.EvaluateSplits(nidx, max_active_features, dh::ToSpan(d_node_inputs), shared_inputs, dh::ToSpan(entries)); + +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaMemcpyAsync(pinned_candidates_out.data(), entries.data().get(), sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost)); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipMemcpyAsync(pinned_candidates_out.data(), + entries.data().get(), sizeof(GPUExpandEntry) * entries.size(), + hipMemcpyDeviceToHost)); +#endif + dh::DefaultStream().Sync(); } @@ -436,9 +472,17 @@ struct GPUHistMakerDevice { } dh::TemporaryArray<RegTree::Node>
d_nodes(p_tree->GetNodes().size()); + +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaMemcpyAsync(d_nodes.data().get(), p_tree->GetNodes().data(), d_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice)); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipMemcpyAsync(d_nodes.data().get(), p_tree->GetNodes().data(), + d_nodes.size() * sizeof(RegTree::Node), + hipMemcpyHostToDevice)); +#endif + auto const& h_split_types = p_tree->GetSplitTypes(); auto const& categories = p_tree->GetSplitCategories(); auto const& categories_segments = p_tree->GetSplitCategoriesPtr(); @@ -508,9 +552,16 @@ struct GPUHistMakerDevice { auto s_position = p_out_position->ConstDeviceSpan(); positions.resize(s_position.size()); + +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaMemcpyAsync(positions.data().get(), s_position.data(), s_position.size_bytes(), cudaMemcpyDeviceToDevice, ctx_->CUDACtx()->Stream())); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipMemcpyAsync(positions.data().get(), s_position.data(), + s_position.size_bytes(), hipMemcpyDeviceToDevice, + ctx_->CUDACtx()->Stream())); +#endif dh::LaunchN(row_partitioner->GetRows().size(), [=] __device__(size_t idx) { bst_node_t position = d_out_position[idx]; @@ -525,7 +576,12 @@ struct GPUHistMakerDevice { } CHECK(p_tree); + +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(ctx_->gpu_id)); +#endif CHECK_EQ(out_preds_d.DeviceIdx(), ctx_->gpu_id); auto d_position = dh::ToSpan(positions); @@ -533,9 +589,17 @@ struct GPUHistMakerDevice { auto const& h_nodes = p_tree->GetNodes(); dh::caching_device_vector<RegTree::Node> nodes(h_nodes.size()); + +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaMemcpyAsync(nodes.data().get(), h_nodes.data(), h_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice, ctx_->CUDACtx()->Stream())); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipMemcpyAsync(nodes.data().get(), h_nodes.data(), + h_nodes.size() *
sizeof(RegTree::Node), hipMemcpyHostToDevice, + ctx_->CUDACtx()->Stream())); +#endif + auto d_nodes = dh::ToSpan(nodes); dh::LaunchN(d_position.size(), ctx_->CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t idx) mutable { @@ -793,7 +857,12 @@ class GPUHistMaker : public TreeUpdater { } ++t_idx; } + +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaGetLastError()); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipGetLastError()); +#endif } catch (const std::exception& e) { LOG(FATAL) << "Exception in gpu_hist: " << e.what() << std::endl; } @@ -813,7 +882,12 @@ class GPUHistMaker : public TreeUpdater { param->max_bin, }; auto page = (*dmat->GetBatches(batch_param).begin()).Impl(); +#if defined(XGBOOST_USE_CUDA) dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); +#elif defined(XGBOOST_USE_HIP) + dh::safe_cuda(hipSetDevice(ctx_->gpu_id)); +#endif + info_->feature_types.SetDevice(ctx_->gpu_id); maker.reset(new GPUHistMakerDevice( ctx_, page, info_->feature_types.ConstDeviceSpan(), info_->num_row_, *param, diff --git a/src/tree/updater_gpu_hist.hip b/src/tree/updater_gpu_hist.hip index e69de29bb..e0f3be6a3 100644 --- a/src/tree/updater_gpu_hist.hip +++ b/src/tree/updater_gpu_hist.hip @@ -0,0 +1,4 @@ + +#if defined(XGBOOST_USE_HIP) +#include "updater_gpu_hist.cu" +#endif