add HIP flags

This commit is contained in:
amdsc21 2023-03-08 01:22:25 +01:00
parent 75712b9c3c
commit 6b7be96373
2 changed files with 5 additions and 5 deletions

View File

@ -98,7 +98,7 @@ class Communicator {
/** @brief Get the communicator instance. */
static Communicator *Get() { return communicator_.get(); }
#if defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
/**
* @brief Get the device communicator.
*

View File

@ -271,9 +271,9 @@ class GBTree : public GradientBooster {
CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
std::vector<Predictor const *> predictors{
cpu_predictor_.get(),
#if defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
gpu_predictor_.get()
#endif // defined(XGBOOST_USE_CUDA)
#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
};
StringView msg{"Unsupported data type for inplace predict."};
if (tparam_.predictor == PredictorType::kAuto) {
@ -441,9 +441,9 @@ class GBTree : public GradientBooster {
std::vector<std::unique_ptr<TreeUpdater>> updaters_;
// Predictors
std::unique_ptr<Predictor> cpu_predictor_;
#if defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
std::unique_ptr<Predictor> gpu_predictor_;
#endif // defined(XGBOOST_USE_CUDA)
#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
#if defined(XGBOOST_USE_ONEAPI)
std::unique_ptr<Predictor> oneapi_predictor_;
#endif // defined(XGBOOST_USE_ONEAPI)