finish fit_stump.cu

This commit is contained in:
amdsc21 2023-03-10 00:46:29 +01:00
parent 1530c03f7d
commit 1c58ff61d1
4 changed files with 19 additions and 3 deletions

View File

@ -15,7 +15,7 @@
#include "constraints.cuh"
#include "param.h"
#if defined(XGBOOST_USE_hip.CUDA)
#if defined(XGBOOST_USE_CUDA)
#include "../common/device_helpers.cuh"
#elif defined(XGBOOST_USE_HIP)
#include "../common/device_helpers.hip.h"

View File

@ -56,12 +56,12 @@ namespace cuda_impl {
void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpair,
linalg::VectorView<float> out);
#if !defined(XGBOOST_USE_CUDA)
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
inline void FitStump(Context const*, linalg::TensorView<GradientPair const, 2>,
linalg::VectorView<float>) {
common::AssertGPUSupport();
}
#endif // !defined(XGBOOST_USE_CUDA)
#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_C
} // namespace cuda_impl
void FitStump(Context const* ctx, HostDeviceVector<GradientPair> const& gpair,

View File

@ -12,7 +12,13 @@
#include <cstddef> // std::size_t
#include "../collective/device_communicator.cuh" // DeviceCommunicator
#if defined(XGBOOST_USE_CUDA)
#include "../common/device_helpers.cuh" // dh::MakeTransformIterator
#elif defined(XGBOOST_USE_HIP)
#include "../common/device_helpers.hip.h" // dh::MakeTransformIterator
#endif
#include "fit_stump.h"
#include "xgboost/base.h" // GradientPairPrecise, GradientPair, XGBOOST_DEVICE
#include "xgboost/context.h" // Context
@ -45,7 +51,13 @@ void FitStump(Context const* ctx, linalg::TensorView<GradientPair const, 2> gpai
CHECK(d_sum.CContiguous());
dh::XGBCachingDeviceAllocator<char> alloc;
#if defined(XGBOOST_USE_CUDA)
auto policy = thrust::cuda::par(alloc);
#elif defined(XGBOOST_USE_HIP)
auto policy = thrust::hip::par(alloc);
#endif
thrust::reduce_by_key(policy, key_it, key_it + gpair.Size(), grad_it,
thrust::make_discard_iterator(), dh::tbegin(d_sum.Values()));

View File

@ -0,0 +1,4 @@
#if defined(XGBOOST_USE_HIP)
#include "fit_stump.cu"
#endif