array interface
This commit is contained in:
parent
8e703f3a5a
commit
206f305b65
@ -1,39 +1,64 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2021-2023, XGBoost Contributors
|
* Copyright 2021-2023, XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#include <cstdint> // for int64_t
|
|
||||||
|
|
||||||
|
#include <cstdint> // for int64_t
|
||||||
#include "../common/common.h"
|
#include "../common/common.h"
|
||||||
#include "../common/device_helpers.cuh" // for DefaultStream, CUDAEvent
|
#include "../common/device_helpers.cuh" // for DefaultStream, CUDAEvent
|
||||||
#include "array_interface.h"
|
#include "array_interface.h"
|
||||||
#include "xgboost/logging.h"
|
#include "xgboost/logging.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
|
|
||||||
void ArrayInterfaceHandler::SyncCudaStream(std::int64_t stream) {
|
void ArrayInterfaceHandler::SyncCudaStream(std::int64_t stream) {
|
||||||
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
if (stream == 1 || stream == 2) {
|
||||||
|
// Default streams, no synchronization needed
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
hipEvent_t event;
|
||||||
|
hipError_t err = hipEventCreate(&event);
|
||||||
|
if (err != hipSuccess) {
|
||||||
|
LOG(WARNING) << "Failed to create HIP event: " << hipGetErrorString(err);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
err = hipEventRecord(event, reinterpret_cast<hipStream_t>(stream));
|
||||||
|
if (err != hipSuccess) {
|
||||||
|
LOG(WARNING) << "Failed to record HIP event: " << hipGetErrorString(err);
|
||||||
|
hipEventDestroy(event);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
err = hipStreamWaitEvent(hipStreamPerThread, event, 0);
|
||||||
|
if (err != hipSuccess) {
|
||||||
|
LOG(WARNING) << "Failed to wait for HIP event: " << hipGetErrorString(err);
|
||||||
|
}
|
||||||
|
|
||||||
|
hipEventDestroy(event);
|
||||||
|
#else
|
||||||
switch (stream) {
|
switch (stream) {
|
||||||
case 0:
|
case 0:
|
||||||
/**
|
/**
|
||||||
* disallowed by the `__cuda_array_interface__`. Quote:
|
* disallowed by the *`__cuda_array_interface__`*. Quote:
|
||||||
*
|
*
|
||||||
* This is disallowed as it would be ambiguous between None and the default
|
* This is disallowed as it would be ambiguous between None and the default
|
||||||
* stream, and also between the legacy and per-thread default streams. Any use
|
* stream, and also between the legacy and per-thread default streams. Any use
|
||||||
* case where 0 might be given should either use None, 1, or 2 instead for
|
* case where 0 might be given should either use None, 1, or 2 instead for
|
||||||
* clarity.
|
* clarity.
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_USE_HIP
|
|
||||||
LOG(FATAL) << "Invalid stream ID in array interface: " << stream;
|
LOG(FATAL) << "Invalid stream ID in array interface: " << stream;
|
||||||
#endif
|
case 1: // default legacy stream
|
||||||
case 1:
|
|
||||||
// default legacy stream
|
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2: // default per-thread stream
|
||||||
// default per-thread stream
|
|
||||||
default: {
|
default: {
|
||||||
dh::CUDAEvent e;
|
dh::CUDAEvent e;
|
||||||
e.Record(dh::CUDAStreamView{reinterpret_cast<cudaStream_t>(stream)});
|
e.Record(dh::CUDAStreamView{reinterpret_cast<cudaStream_t>(stream)});
|
||||||
dh::DefaultStream().Wait(e);
|
dh::DefaultStream().Wait(e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
|
bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
|
||||||
@ -41,7 +66,38 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_HIP)
|
||||||
|
hipPointerAttribute_t attr;
|
||||||
|
auto err = hipPointerGetAttributes(&attr, ptr);
|
||||||
|
|
||||||
|
// Reset error
|
||||||
|
hipError_t last_error = hipGetLastError();
|
||||||
|
if (last_error != hipSuccess) {
|
||||||
|
LOG(WARNING) << "HIP error after hipPointerGetAttributes: "
|
||||||
|
<< hipGetErrorString(last_error);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (err == hipErrorInvalidValue) {
|
||||||
|
return false;
|
||||||
|
} else if (err == hipSuccess) {
|
||||||
|
// For ROCm 6.2.2, we use the `type` field
|
||||||
|
switch (attr.type) {
|
||||||
|
case hipMemoryTypeUnregistered:
|
||||||
|
case hipMemoryTypeHost:
|
||||||
|
return false;
|
||||||
|
case hipMemoryTypeDevice:
|
||||||
|
case hipMemoryTypeManaged:
|
||||||
|
return true;
|
||||||
|
default:
|
||||||
|
LOG(WARNING) << "Unknown memory type: " << attr.type;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
LOG(WARNING) << "hipPointerGetAttributes failed with error: "
|
||||||
|
<< hipGetErrorString(err);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
#elif defined(XGBOOST_USE_CUDA)
|
||||||
cudaPointerAttributes attr;
|
cudaPointerAttributes attr;
|
||||||
auto err = cudaPointerGetAttributes(&attr, ptr);
|
auto err = cudaPointerGetAttributes(&attr, ptr);
|
||||||
// reset error
|
// reset error
|
||||||
@ -63,34 +119,9 @@ bool ArrayInterfaceHandler::IsCudaPtr(void const* ptr) {
|
|||||||
// other errors, `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc.
|
// other errors, `cudaErrorNoDevice`, `cudaErrorInsufficientDriver` etc.
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
#elif defined(XGBOOST_USE_HIP)
|
|
||||||
hipPointerAttribute_t attr;
|
|
||||||
auto err = hipPointerGetAttributes(&attr, ptr);
|
|
||||||
// reset error
|
|
||||||
CHECK_EQ(err, hipGetLastError());
|
|
||||||
if (err == hipErrorInvalidValue) {
|
|
||||||
return false;
|
|
||||||
} else if (err == hipSuccess) {
|
|
||||||
#if HIP_VERSION_MAJOR < 6
|
|
||||||
switch (attr.memoryType) {
|
|
||||||
case hipMemoryTypeHost:
|
|
||||||
return false;
|
|
||||||
default:
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
#else
|
#else
|
||||||
switch (attr.type) {
|
return false;
|
||||||
case hipMemoryTypeUnregistered:
|
|
||||||
case hipMemoryTypeHost:
|
|
||||||
return false;
|
|
||||||
default:
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
Loading…
x
Reference in New Issue
Block a user