finish gradient_index.cu

This commit is contained in:
amdsc21 2023-03-10 04:40:33 +01:00
parent 6e2c5be83e
commit 134cbfddbe
3 changed files with 13 additions and 5 deletions

View File

@ -23,7 +23,11 @@ void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) {
case 2: case 2:
// default per-thread stream // default per-thread stream
default: default:
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream))); dh::safe_cuda(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream)));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipStreamSynchronize(reinterpret_cast<hipStream_t>(stream)));
#endif
} }
} }
@ -50,12 +54,12 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
return true; return true;
} }
return true; return true;
#elif defined(XGBOOST_USE_HIP)
return false;
#endif
} else { } else {
// other errors, `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc. // other errors, `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc.
return false; return false;
} }
#elif defined(XGBOOST_USE_HIP)
return false;
#endif
} }
} // namespace xgboost } // namespace xgboost

View File

@ -67,12 +67,12 @@ GHistIndexMatrix::GHistIndexMatrix(MetaInfo const &info, common::HistogramCuts &
max_numeric_bins_per_feat(max_bin_per_feat), max_numeric_bins_per_feat(max_bin_per_feat),
isDense_{info.num_col_ * info.num_row_ == info.num_nonzero_} {} isDense_{info.num_col_ * info.num_row_ == info.num_nonzero_} {}
#if !defined(XGBOOST_USE_CUDA) #if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
GHistIndexMatrix::GHistIndexMatrix(Context const *, MetaInfo const &, EllpackPage const &, GHistIndexMatrix::GHistIndexMatrix(Context const *, MetaInfo const &, EllpackPage const &,
BatchParam const &) { BatchParam const &) {
common::AssertGPUSupport(); common::AssertGPUSupport();
} }
#endif // defined(XGBOOST_USE_CUDA) #endif // defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
GHistIndexMatrix::~GHistIndexMatrix() = default; GHistIndexMatrix::~GHistIndexMatrix() = default;

View File

@ -0,0 +1,4 @@
#if defined(XGBOOST_USE_HIP)
#include "gradient_index.cu"
#endif