enable ROCm on latest XGBoost

This commit is contained in:
Hui Liu
2023-10-23 11:07:08 -07:00
328 changed files with 8028 additions and 3642 deletions

View File

@@ -148,19 +148,18 @@ class ElementWiseSurvivalMetricsReduction {
const HostDeviceVector<bst_float>& preds) {
PackedReduceResult result;
if (ctx.gpu_id < 0) {
if (ctx.IsCPU()) {
result = CpuReduceMetrics(weights, labels_lower_bound, labels_upper_bound,
preds, ctx.Threads());
}
#if defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)
else { // NOLINT
preds.SetDevice(ctx.gpu_id);
labels_lower_bound.SetDevice(ctx.gpu_id);
labels_upper_bound.SetDevice(ctx.gpu_id);
weights.SetDevice(ctx.gpu_id);
dh::safe_cuda(cudaSetDevice(ctx.gpu_id));
preds.SetDevice(ctx.Device());
labels_lower_bound.SetDevice(ctx.Device());
labels_upper_bound.SetDevice(ctx.Device());
weights.SetDevice(ctx.Device());
dh::safe_cuda(cudaSetDevice(ctx.Ordinal()));
result = DeviceReduceMetrics(weights, labels_lower_bound, labels_upper_bound, preds);
}
#endif // defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)