fix, to support hip

This commit is contained in:
amdsc21
2023-05-02 17:43:11 +02:00
parent 5446c501af
commit e4538cb13c
2 changed files with 26 additions and 6 deletions

View File

@@ -66,7 +66,7 @@ void IterativeDMatrix::InitFromCUDA(Context const* ctx, BatchParam const& p,
do {
// We use do while here as the first batch is fetched in ctor
// ctx_.gpu_id = proxy->DeviceIdx();
CHECK_LT(ctx_.gpu_id, common::AllVisibleGPUs());
CHECK_LT(ctx->gpu_id, common::AllVisibleGPUs());
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(get_device()));