finish gradient_index.cu
This commit is contained in:
parent
6e2c5be83e
commit
134cbfddbe
@ -23,7 +23,11 @@ void ArrayInterfaceHandler::SyncCudaStream(int64_t stream) {
|
|||||||
case 2:
|
case 2:
|
||||||
// default per-thread stream
|
// default per-thread stream
|
||||||
default:
|
default:
|
||||||
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
dh::safe_cuda(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream)));
|
dh::safe_cuda(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(stream)));
|
||||||
|
#elif defined(XGBOOST_USE_HIP)
|
||||||
|
dh::safe_cuda(hipStreamSynchronize(reinterpret_cast<hipStream_t>(stream)));
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -50,12 +54,12 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
#elif defined(XGBOOST_USE_HIP)
|
|
||||||
return false;
|
|
||||||
#endif
|
|
||||||
} else {
|
} else {
|
||||||
// other errors, `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc.
|
// other errors, `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc.
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
#elif defined(XGBOOST_USE_HIP)
|
||||||
|
return false;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -67,12 +67,12 @@ GHistIndexMatrix::GHistIndexMatrix(MetaInfo const &info, common::HistogramCuts &
|
|||||||
max_numeric_bins_per_feat(max_bin_per_feat),
|
max_numeric_bins_per_feat(max_bin_per_feat),
|
||||||
isDense_{info.num_col_ * info.num_row_ == info.num_nonzero_} {}
|
isDense_{info.num_col_ * info.num_row_ == info.num_nonzero_} {}
|
||||||
|
|
||||||
#if !defined(XGBOOST_USE_CUDA)
|
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
|
||||||
GHistIndexMatrix::GHistIndexMatrix(Context const *, MetaInfo const &, EllpackPage const &,
|
GHistIndexMatrix::GHistIndexMatrix(Context const *, MetaInfo const &, EllpackPage const &,
|
||||||
BatchParam const &) {
|
BatchParam const &) {
|
||||||
common::AssertGPUSupport();
|
common::AssertGPUSupport();
|
||||||
}
|
}
|
||||||
#endif // defined(XGBOOST_USE_CUDA)
|
#endif // defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
|
||||||
|
|
||||||
GHistIndexMatrix::~GHistIndexMatrix() = default;
|
GHistIndexMatrix::~GHistIndexMatrix() = default;
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,4 @@
|
|||||||
|
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
#include "gradient_index.cu"
|
||||||
|
#endif
|
||||||
Loading…
x
Reference in New Issue
Block a user