[coll] Pass context to various functions. (#9772)
* [coll] Pass context to various functions. In the future, the `Context` object would be required for collective operations, this PR passes the context object to some required functions to prepare for swapping out the implementation.
This commit is contained in:
@@ -199,9 +199,9 @@ void Transpose(common::Span<float const> in, common::Span<float> out, size_t m,
|
||||
});
|
||||
}
|
||||
|
||||
double ScaleClasses(common::Span<double> results, common::Span<double> local_area,
|
||||
common::Span<double> tp, common::Span<double> auc, size_t n_classes) {
|
||||
dh::XGBDeviceAllocator<char> alloc;
|
||||
double ScaleClasses(Context const *ctx, common::Span<double> results,
|
||||
common::Span<double> local_area, common::Span<double> tp,
|
||||
common::Span<double> auc, size_t n_classes) {
|
||||
if (collective::IsDistributed()) {
|
||||
int32_t device = dh::CurrentDevice();
|
||||
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), device);
|
||||
@@ -218,8 +218,8 @@ double ScaleClasses(common::Span<double> results, common::Span<double> local_are
|
||||
double tp_sum;
|
||||
double auc_sum;
|
||||
thrust::tie(auc_sum, tp_sum) =
|
||||
thrust::reduce(thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes,
|
||||
Pair{0.0, 0.0}, PairPlus<double, double>{});
|
||||
thrust::reduce(ctx->CUDACtx()->CTP(), reduce_in, reduce_in + n_classes, Pair{0.0, 0.0},
|
||||
PairPlus<double, double>{});
|
||||
if (tp_sum != 0 && !std::isnan(auc_sum)) {
|
||||
auc_sum /= tp_sum;
|
||||
} else {
|
||||
@@ -309,10 +309,10 @@ void SegmentedReduceAUC(common::Span<size_t const> d_unique_idx,
|
||||
* up each class in all kernels.
|
||||
*/
|
||||
template <bool scale, typename Fn>
|
||||
double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
|
||||
double GPUMultiClassAUCOVR(Context const *ctx, MetaInfo const &info,
|
||||
common::Span<uint32_t> d_class_ptr, size_t n_classes,
|
||||
std::shared_ptr<DeviceAUCCache> cache, Fn area_fn) {
|
||||
dh::safe_cuda(cudaSetDevice(device.ordinal));
|
||||
dh::safe_cuda(cudaSetDevice(ctx->Ordinal()));
|
||||
/**
|
||||
* Sorted idx
|
||||
*/
|
||||
@@ -320,7 +320,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
|
||||
// Index is sorted within class.
|
||||
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
|
||||
|
||||
auto labels = info.labels.View(device);
|
||||
auto labels = info.labels.View(ctx->Device());
|
||||
auto weights = info.weights_.ConstDeviceSpan();
|
||||
|
||||
size_t n_samples = labels.Shape(0);
|
||||
@@ -328,12 +328,11 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
|
||||
if (n_samples == 0) {
|
||||
dh::TemporaryArray<double> resutls(n_classes * 4, 0.0f);
|
||||
auto d_results = dh::ToSpan(resutls);
|
||||
dh::LaunchN(n_classes * 4,
|
||||
[=] XGBOOST_DEVICE(size_t i) { d_results[i] = 0.0f; });
|
||||
dh::LaunchN(n_classes * 4, [=] XGBOOST_DEVICE(size_t i) { d_results[i] = 0.0f; });
|
||||
auto local_area = d_results.subspan(0, n_classes);
|
||||
auto tp = d_results.subspan(2 * n_classes, n_classes);
|
||||
auto auc = d_results.subspan(3 * n_classes, n_classes);
|
||||
return ScaleClasses(d_results, local_area, tp, auc, n_classes);
|
||||
return ScaleClasses(ctx, d_results, local_area, tp, auc, n_classes);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -437,7 +436,7 @@ double GPUMultiClassAUCOVR(MetaInfo const &info, DeviceOrd device,
|
||||
tp[c] = 1.0f;
|
||||
}
|
||||
});
|
||||
return ScaleClasses(d_results, local_area, tp, auc, n_classes);
|
||||
return ScaleClasses(ctx, d_results, local_area, tp, auc, n_classes);
|
||||
}
|
||||
|
||||
void MultiClassSortedIdx(Context const *ctx, common::Span<float const> predts,
|
||||
@@ -472,8 +471,7 @@ double GPUMultiClassROCAUC(Context const *ctx, common::Span<float const> predts,
|
||||
size_t /*class_id*/) {
|
||||
return TrapezoidArea(fp_prev, fp, tp_prev, tp);
|
||||
};
|
||||
return GPUMultiClassAUCOVR<true>(info, ctx->Device(), dh::ToSpan(class_ptr), n_classes, cache,
|
||||
fn);
|
||||
return GPUMultiClassAUCOVR<true>(ctx, info, dh::ToSpan(class_ptr), n_classes, cache, fn);
|
||||
}
|
||||
|
||||
namespace {
|
||||
@@ -697,7 +695,7 @@ double GPUMultiClassPRAUC(Context const *ctx, common::Span<float const> predts,
|
||||
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
|
||||
d_totals[class_id].first);
|
||||
};
|
||||
return GPUMultiClassAUCOVR<false>(info, ctx->Device(), d_class_ptr, n_classes, cache, fn);
|
||||
return GPUMultiClassAUCOVR<false>(ctx, info, d_class_ptr, n_classes, cache, fn);
|
||||
}
|
||||
|
||||
template <typename Fn>
|
||||
|
||||
Reference in New Issue
Block a user