[coll] Pass context to various functions. (#9772)

In the future, the `Context` object will be required for collective operations. This PR
passes the context object to the relevant functions to prepare for swapping out the
implementation.
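
A sketch of the calling-convention change described above, using stand-in types rather than the real XGBoost declarations: device state moves out of individual arguments into a single `Context` pointer threaded through the call chain, so the backing implementation can later be swapped in one place.

```cpp
// Sketch of the pattern this commit applies (hypothetical simplified types,
// not the actual XGBoost signatures).
#include <cstdio>

struct Context {
  int ordinal{0};
  int Ordinal() const { return ordinal; }  // device ordinal, cf. ctx->Ordinal()
};

// Leaf function: previously took raw device/allocator state, now the context.
double ScaleClasses(Context const *ctx, double auc_sum, int n_classes) {
  (void)ctx;  // the real body reduces per-class AUC on ctx's device
  return n_classes > 0 ? auc_sum / n_classes : 0.0;
}

// Caller forwards the same context instead of unpacking device state.
double GPUMultiClassAUCOVR(Context const *ctx, double auc_sum, int n_classes) {
  return ScaleClasses(ctx, auc_sum, n_classes);
}

int main() {
  Context ctx;
  std::printf("%.3f\n", GPUMultiClassAUCOVR(&ctx, 2.4, 3));  // 0.800
}
```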
Author: Jiaming Yuan
Date: 2023-11-08 09:54:05 +08:00
Committed by: GitHub
Parent: 6c0a190f6d
Commit: 06bdc15e9b
45 changed files with 275 additions and 255 deletions

@@ -199,9 +199,9 @@ void Transpose(common::Span<float const> in, common::Span<float> out, size_t m,
   });
 }
-double ScaleClasses(common::Span<double> results, common::Span<double> local_area,
-                    common::Span<double> tp, common::Span<double> auc, size_t n_classes) {
-  dh::XGBDeviceAllocator<char> alloc;
+double ScaleClasses(Context const *ctx, common::Span<double> results,
+                    common::Span<double> local_area, common::Span<double> tp,
+                    common::Span<double> auc, size_t n_classes) {
   if (collective::IsDistributed()) {
     int32_t device = dh::CurrentDevice();
     CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
@@ -218,8 +218,8 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
   double tp_sum;
   double auc_sum;
   thrust::tie(auc_sum, tp_sum) =
-      thrust::reduce(thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes,
-                     Pair{0.0, 0.0}, PairPlus<double, double>{});
+      thrust::reduce(ctx->CUDACtx()->CTP(), reduce_in, reduce_in + n_classes, Pair{0.0, 0.0},
+                     PairPlus<double, double>{});
   if (tp_sum != 0 && !std::isnan(auc_sum)) {
     auc_sum /= tp_sum;
   } else {
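
This hunk carries the substantive change: the temporary storage for the reduction no longer comes from a function-local `dh::XGBDeviceAllocator`, but from the execution policy the context hands back via `ctx->CUDACtx()->CTP()`. A standalone Thrust analogue of the role that policy plays (plain `thrust::device` standing in for the context's cached, stream-aware policy):

```cpp
// Standalone Thrust sketch (compile with nvcc): the first argument to
// thrust::reduce is an execution policy, which decides where temporary
// storage comes from. The old code built one per call with
// thrust::cuda::par(alloc); the new code asks the Context for a reusable one.
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include <cstdio>

int main() {
  thrust::device_vector<double> vals(256, 0.5);
  // thrust::device plays the role of ctx->CUDACtx()->CTP() here; XGBoost's
  // policy additionally carries the context's CUDA stream and a caching
  // allocator, so repeated reductions avoid re-allocating scratch memory.
  double sum = thrust::reduce(thrust::device, vals.begin(), vals.end(), 0.0);
  std::printf("sum = %.1f\n", sum);  // 128.0
}
```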
@@ -309,10 +309,10 @@ void SegmentedReduceAUC(common::Span<size_t const> d_unique_idx,
  * up each class in all kernels.
  */
 template <bool scale, typename Fn>
-double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
+double GPUMultiClassAUCOVR(Context const *ctx, MetaInfo const &info,
                            common::Span<uint32_t> d_class_ptr, size_t n_classes,
                            std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
-  dh::safe_cuda(cudaSetDevice(device.ordinal));
+  dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
   /**
    * Sorted idx
    */
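
`ctx->Ordinal()` replaces the unpacked `device.ordinal`; either way the value ends up in `cudaSetDevice`, which binds the calling host thread to that device for subsequent kernel launches. A minimal runnable sketch, with the error handling written out where `dh::safe_cuda` would wrap it:

```cpp
// Minimal sketch of the call behind dh::safe_cuda(cudaSetDevice(...)):
// bind the calling host thread to a device ordinal before launching work.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  int ordinal = 0;  // stands in for ctx->Ordinal()
  cudaError_t err = cudaSetDevice(ordinal);
  if (err != cudaSuccess) {  // dh::safe_cuda raises on non-success instead
    std::fprintf(stderr, "cudaSetDevice: %s\n", cudaGetErrorString(err));
    return 1;
  }
  std::puts("device bound");
  return 0;
}
```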
@@ -320,7 +320,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
   // Index is sorted within class.
   auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
-  auto labels = info.labels.View(device);
+  auto labels = info.labels.View(ctx->Device());
   auto weights = info.weights_.ConstDeviceSpan();

   size_t n_samples = labels.Shape(0);
@@ -328,12 +328,11 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
   if (n_samples == 0) {
     dh::TemporaryArray<double> resutls(n_classes * 4, 0.0f);
     auto d_results = dh::ToSpan(resutls);
-    dh::LaunchN(n_classes * 4,
-                [=] XGBOOST_DEVICE(size_t i) { d_results[i] = 0.0f; });
+    dh::LaunchN(n_classes * 4, [=] XGBOOST_DEVICE(size_t i) { d_results[i] = 0.0f; });
     auto local_area = d_results.subspan(0, n_classes);
     auto tp = d_results.subspan(2 * n_classes, n_classes);
     auto auc = d_results.subspan(3 * n_classes, n_classes);
-    return ScaleClasses(d_results, local_area, tp, auc, n_classes);
+    return ScaleClasses(ctx, d_results, local_area, tp, auc, n_classes);
   }
   /**
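
Aside from threading `ctx` through, this hunk shows how a single buffer of `4 * n_classes` doubles is carved into named views. A host-side `std::span` analogue (C++20) of the `dh::ToSpan` + `subspan` layout, with offsets taken from the diff; the block at offset `n_classes` is not touched in this hunk:

```cpp
// Host-side sketch of the buffer layout used above: one allocation of
// 4 * n_classes doubles, viewed through non-owning named sub-spans.
#include <cstddef>
#include <cstdio>
#include <span>
#include <vector>

int main() {
  std::size_t n_classes = 3;
  std::vector<double> results(4 * n_classes, 0.0);
  std::span<double> d_results{results};
  auto local_area = d_results.subspan(0, n_classes);
  auto tp = d_results.subspan(2 * n_classes, n_classes);
  auto auc = d_results.subspan(3 * n_classes, n_classes);
  std::printf("%zu %zu %zu\n", local_area.size(), tp.size(), auc.size());
}
```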
@@ -437,7 +436,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
       tp[c] = 1.0f;
     }
   });
-  return ScaleClasses(d_results, local_area, tp, auc, n_classes);
+  return ScaleClasses(ctx, d_results, local_area, tp, auc, n_classes);
 }

 void MultiClassSortedIdx(Context const *ctx, common::Span<float const> predts,
@@ -472,8 +471,7 @@ double GPUMultiClassROCAUC(Context const *ctx, common::Span<float const> predts,
                               size_t /*class_id*/) {
     return TrapezoidArea(fp_prev, fp, tp_prev, tp);
   };
-  return GPUMultiClassAUCOVR<true>(info, ctx->Device(), dh::ToSpan(class_ptr), n_classes, cache,
-                                   fn);
+  return GPUMultiClassAUCOVR<true>(ctx, info, dh::ToSpan(class_ptr), n_classes, cache, fn);
 }

 namespace {
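
For reference, `TrapezoidArea` in the call above computes the area increment between two consecutive ROC points. A scalar sketch of the standard trapezoid rule it presumably implements:

```cpp
// Scalar sketch of the trapezoid rule presumably behind TrapezoidArea:
// area between consecutive ROC points (fp_prev, tp_prev) and (fp, tp).
#include <cstdio>

double TrapezoidArea(double fp_prev, double fp, double tp_prev, double tp) {
  return 0.5 * (tp_prev + tp) * (fp - fp_prev);
}

int main() {
  // One step from (0, 0) to (0.5, 1.0) contributes a quarter of the unit box.
  std::printf("%.2f\n", TrapezoidArea(0.0, 0.5, 0.0, 1.0));  // 0.25
}
```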
@@ -697,7 +695,7 @@ double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
     return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
                                   d_totals[class_id].first);
   };
-  return GPUMultiClassAUCOVR<false>(info, ctx->Device(), d_class_ptr, n_classes, cache, fn);
+  return GPUMultiClassAUCOVR<false>(ctx, info, d_class_ptr, n_classes, cache, fn);
 }

 template <typename Fn>