add HIP flags

This commit is contained in:
amdsc21
2023-03-08 01:33:38 +01:00
parent 6b7be96373
commit f5f800c80d
7 changed files with 20 additions and 20 deletions

View File

@@ -47,10 +47,10 @@ class DMatrixProxy : public DMatrix {
dmlc::any batch_;
Context ctx_;
#if defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
void FromCudaColumnar(StringView interface_str);
void FromCudaArray(StringView interface_str);
#endif // defined(XGBOOST_USE_CUDA)
#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
public:
int DeviceIdx() const { return ctx_.gpu_id; }
@@ -58,7 +58,7 @@ class DMatrixProxy : public DMatrix {
void SetCUDAArray(char const* c_interface) {
common::AssertGPUSupport();
CHECK(c_interface);
#if defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
StringView interface_str{c_interface};
Json json_array_interface = Json::Load(interface_str);
if (IsA<Array>(json_array_interface)) {
@@ -66,7 +66,7 @@ class DMatrixProxy : public DMatrix {
} else {
this->FromCudaArray(interface_str);
}
#endif // defined(XGBOOST_USE_CUDA)
#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
}
void SetArrayData(char const* c_interface);