add HIP flags
This commit is contained in:
@@ -47,10 +47,10 @@ class DMatrixProxy : public DMatrix {
|
||||
dmlc::any batch_;
|
||||
Context ctx_;
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
|
||||
void FromCudaColumnar(StringView interface_str);
|
||||
void FromCudaArray(StringView interface_str);
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
|
||||
|
||||
public:
|
||||
int DeviceIdx() const { return ctx_.gpu_id; }
|
||||
@@ -58,7 +58,7 @@ class DMatrixProxy : public DMatrix {
|
||||
void SetCUDAArray(char const* c_interface) {
|
||||
common::AssertGPUSupport();
|
||||
CHECK(c_interface);
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
|
||||
StringView interface_str{c_interface};
|
||||
Json json_array_interface = Json::Load(interface_str);
|
||||
if (IsA<Array>(json_array_interface)) {
|
||||
@@ -66,7 +66,7 @@ class DMatrixProxy : public DMatrix {
|
||||
} else {
|
||||
this->FromCudaArray(interface_str);
|
||||
}
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
|
||||
}
|
||||
|
||||
void SetArrayData(char const* c_interface);
|
||||
|
||||
Reference in New Issue
Block a user