add HIP flags
This commit is contained in:
parent
75712b9c3c
commit
6b7be96373
@ -98,7 +98,7 @@ class Communicator {
|
|||||||
/** @brief Get the communicator instance. */
|
/** @brief Get the communicator instance. */
|
||||||
static Communicator *Get() { return communicator_.get(); }
|
static Communicator *Get() { return communicator_.get(); }
|
||||||
|
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
|
||||||
/**
|
/**
|
||||||
* @brief Get the device communicator.
|
* @brief Get the device communicator.
|
||||||
*
|
*
|
||||||
|
|||||||
@ -271,9 +271,9 @@ class GBTree : public GradientBooster {
|
|||||||
CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
|
CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
|
||||||
std::vector<Predictor const *> predictors{
|
std::vector<Predictor const *> predictors{
|
||||||
cpu_predictor_.get(),
|
cpu_predictor_.get(),
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
|
||||||
gpu_predictor_.get()
|
gpu_predictor_.get()
|
||||||
#endif // defined(XGBOOST_USE_CUDA)
|
#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
|
||||||
};
|
};
|
||||||
StringView msg{"Unsupported data type for inplace predict."};
|
StringView msg{"Unsupported data type for inplace predict."};
|
||||||
if (tparam_.predictor == PredictorType::kAuto) {
|
if (tparam_.predictor == PredictorType::kAuto) {
|
||||||
@ -441,9 +441,9 @@ class GBTree : public GradientBooster {
|
|||||||
std::vector<std::unique_ptr<TreeUpdater>> updaters_;
|
std::vector<std::unique_ptr<TreeUpdater>> updaters_;
|
||||||
// Predictors
|
// Predictors
|
||||||
std::unique_ptr<Predictor> cpu_predictor_;
|
std::unique_ptr<Predictor> cpu_predictor_;
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
|
||||||
std::unique_ptr<Predictor> gpu_predictor_;
|
std::unique_ptr<Predictor> gpu_predictor_;
|
||||||
#endif // defined(XGBOOST_USE_CUDA)
|
#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
|
||||||
#if defined(XGBOOST_USE_ONEAPI)
|
#if defined(XGBOOST_USE_ONEAPI)
|
||||||
std::unique_ptr<Predictor> oneapi_predictor_;
|
std::unique_ptr<Predictor> oneapi_predictor_;
|
||||||
#endif // defined(XGBOOST_USE_ONEAPI)
|
#endif // defined(XGBOOST_USE_ONEAPI)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user