array interface

This commit is contained in:
Hendrik Groove 2024-10-20 01:28:40 +02:00
parent 8e703f3a5a
commit 206f305b65

View File

@ -1,39 +1,64 @@
/**
* Copyright 2021-2023, XGBoost Contributors
/**
* Copyright 2021-2023, XGBoost Contributors
*/
#include <cstdint> // for int64_t
#include <cstdint> // for int64_t
#include "../common/common.h"
#include "../common/device_helpers.cuh" // for DefaultStream, CUDAEvent
#include "../common/device_helpers.cuh" // for DefaultStream, CUDAEvent
#include "array_interface.h"
#include "xgboost/logging.h"
namespace xgboost {
void ArrayInterfaceHandler::SyncCudaStream(std::int64_t stream) {
#if defined(XGBOOST_USE_HIP)
if (stream == 1 || stream == 2) {
// Default streams, no synchronization needed
return;
}
hipEvent_t event;
hipError_t err = hipEventCreate(&event);
if (err != hipSuccess) {
LOG(WARNING) << "Failed to create HIP event: " << hipGetErrorString(err);
return;
}
err = hipEventRecord(event, reinterpret_cast<hipStream_t>(stream));
if (err != hipSuccess) {
LOG(WARNING) << "Failed to record HIP event: " << hipGetErrorString(err);
hipEventDestroy(event);
return;
}
err = hipStreamWaitEvent(hipStreamPerThread, event, 0);
if (err != hipSuccess) {
LOG(WARNING) << "Failed to wait for HIP event: " << hipGetErrorString(err);
}
hipEventDestroy(event);
#else
switch (stream) {
case 0:
/**
* disallowed by the `__cuda_array_interface__`. Quote:
* disallowed by the *`__cuda_array_interface__`*. Quote:
*
* This is disallowed as it would be ambiguous between None and the default
* stream, and also between the legacy and per-thread default streams. Any use
* case where 0 might be given should either use None, 1, or 2 instead for
* clarity.
* This is disallowed as it would be ambiguous between None and the default
* stream, and also between the legacy and per-thread default streams. Any use
* case where 0 might be given should either use None, 1, or 2 instead for
* clarity.
*/
#ifndef XGBOOST_USE_HIP
LOG(FATAL) << "Invalid stream ID in array interface: " << stream;
#endif
case 1:
// default legacy stream
case 1: // default legacy stream
break;
case 2:
// default per-thread stream
case 2: // default per-thread stream
default: {
dh::CUDAEvent e;
e.Record(dh::CUDAStreamView{reinterpret_cast<cudaStream_t>(stream)});
dh::DefaultStream().Wait(e);
}
}
#endif
}
bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
@ -41,7 +66,38 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
return false;
}
#if defined(XGBOOST_USE_CUDA)
#if defined(XGBOOST_USE_HIP)
hipPointerAttribute_t attr;
auto err = hipPointerGetAttributes(&attr, ptr);
// Reset error
hipError_t last_error = hipGetLastError();
if (last_error != hipSuccess) {
LOG(WARNING) << "HIP error after hipPointerGetAttributes: "
<< hipGetErrorString(last_error);
}
if (err == hipErrorInvalidValue) {
return false;
} else if (err == hipSuccess) {
// For ROCm 6.2.2, we use the `type` field
switch (attr.type) {
case hipMemoryTypeUnregistered:
case hipMemoryTypeHost:
return false;
case hipMemoryTypeDevice:
case hipMemoryTypeManaged:
return true;
default:
LOG(WARNING) << "Unknown memory type: " << attr.type;
return false;
}
} else {
LOG(WARNING) << "hipPointerGetAttributes failed with error: "
<< hipGetErrorString(err);
return false;
}
#elif defined(XGBOOST_USE_CUDA)
cudaPointerAttributes attr;
auto err = cudaPointerGetAttributes(&attr, ptr);
// reset error
@ -63,34 +119,9 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
// other errors, `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc.
return false;
}
#elif defined(XGBOOST_USE_HIP)
hipPointerAttribute_t attr;
auto err = hipPointerGetAttributes(&attr, ptr);
// reset error
CHECK_EQ(err, hipGetLastError());
if (err == hipErrorInvalidValue) {
return false;
} else if (err == hipSuccess) {
#if HIP_VERSION_MAJOR < 6
switch (attr.memoryType) {
case hipMemoryTypeHost:
return false;
default:
return true;
}
#else
switch (attr.type) {
case hipMemoryTypeUnregistered:
case hipMemoryTypeHost:
return false;
default:
return true;
}
#endif
return true;
} else {
return false;
}
return false;
#endif
}
} // namespace xgboost
} // namespace xgboost