diff --git a/src/metric/auc.cc b/src/metric/auc.cc
index 9184223da..25e671e93 100644
--- a/src/metric/auc.cc
+++ b/src/metric/auc.cc
@@ -87,8 +87,7 @@ std::tuple<float, float, float> BinaryAUC(std::vector<float> const &predts,
  * - Kleiman, Ross and Page, David. $AUC_{\mu}$: A Performance Metric for Multi-Class
  *   Machine Learning Models
  */
-float MultiClassOVR(std::vector<float> const& predts, MetaInfo const& info) {
-  auto n_classes = predts.size() / info.labels_.Size();
+float MultiClassOVR(std::vector<float> const& predts, MetaInfo const& info, size_t n_classes) {
   CHECK_NE(n_classes, 0);
   auto const& labels = info.labels_.ConstHostVector();
 
@@ -230,6 +229,10 @@ class EvalAUC : public Metric {
       info.labels_.SetDevice(tparam_->gpu_id);
       info.weights_.SetDevice(tparam_->gpu_id);
     }
+    //  We use the global size to handle empty dataset.
+    std::array<size_t, 2> meta{info.labels_.Size(), preds.Size()};
+    rabit::Allreduce<rabit::op::Max>(meta.data(), meta.size());
+
     if (!info.group_ptr_.empty()) {
       /**
        * learning to rank
@@ -261,16 +264,17 @@ class EvalAUC : public Metric {
         CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups
                          << ", valid groups: " << valid_groups;
       }
-    } else if (info.labels_.Size() != preds.Size() &&
-               preds.Size() % info.labels_.Size() == 0) {
+    } else if (meta[0] != meta[1] && meta[1] % meta[0] == 0) {
       /**
        * multi class
       */
+      size_t n_classes = meta[1] / meta[0];
+      CHECK_NE(n_classes, 0);
       if (tparam_->gpu_id == GenericParameter::kCpuId) {
-        auc = MultiClassOVR(preds.ConstHostVector(), info);
+        auc = MultiClassOVR(preds.ConstHostVector(), info, n_classes);
       } else {
         auc = GPUMultiClassAUCOVR(preds.ConstDeviceSpan(), info, tparam_->gpu_id,
-                                  &this->d_cache_);
+                                  &this->d_cache_, n_classes);
       }
     } else {
       /**
@@ -323,7 +327,8 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
 }
 
 float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info,
-                          int32_t device, std::shared_ptr<DeviceAUCCache>* cache) {
+                          int32_t device, std::shared_ptr<DeviceAUCCache>* cache,
+                          size_t n_classes) {
   common::AssertGPUSupport();
   return 0;
 }
diff --git a/src/metric/auc.cu b/src/metric/auc.cu
index ea837e413..a550f14e7 100644
--- a/src/metric/auc.cu
+++ b/src/metric/auc.cu
@@ -61,10 +61,12 @@ struct DeviceAUCCache {
       neg_pos.resize(sorted_idx.size());
       if (is_multi) {
         predts_t.resize(sorted_idx.size());
-        reducer.reset(new dh::AllReducer);
-        reducer->Init(rabit::GetRank());
       }
     }
+    if (is_multi && !reducer) {
+      reducer.reset(new dh::AllReducer);
+      reducer->Init(device);
+    }
   }
 };
 
@@ -197,12 +199,48 @@ XGBOOST_DEVICE size_t LastOf(size_t group, common::Span<size_t> indptr) {
   return indptr[group + 1] - 1;
 }
 
+float ScaleClasses(common::Span<float> results, common::Span<float> local_area,
+                   common::Span<float> fp, common::Span<float> tp,
+                   common::Span<float> auc, std::shared_ptr<DeviceAUCCache> cache,
+                   size_t n_classes) {
+  dh::XGBDeviceAllocator<char> alloc;
+  if (rabit::IsDistributed()) {
+    CHECK_EQ(dh::CudaGetPointerDevice(results.data()), dh::CurrentDevice());
+    cache->reducer->AllReduceSum(results.data(), results.data(), results.size());
+  }
+  auto reduce_in = dh::MakeTransformIterator<thrust::pair<float, float>>(
+      thrust::make_counting_iterator(0), [=] __device__(size_t i) {
+        if (local_area[i] > 0) {
+          return thrust::make_pair(auc[i] / local_area[i] * tp[i], tp[i]);
+        }
+        return thrust::make_pair(std::numeric_limits<float>::quiet_NaN(), 0.0f);
+      });
+
+  float tp_sum;
+  float auc_sum;
+  thrust::tie(auc_sum, tp_sum) = thrust::reduce(
+      thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes,
+      thrust::make_pair(0.0f, 0.0f),
+      [=] __device__(auto const &l, auto const &r) {
+        return thrust::make_pair(l.first + r.first, l.second + r.second);
+      });
+  if (tp_sum != 0 && !std::isnan(auc_sum)) {
+    auc_sum /= tp_sum;
+  } else {
+    return std::numeric_limits<float>::quiet_NaN();
+  }
+  return auc_sum;
+}
+
 /**
  * MultiClass implementation is similar to binary classification, except we need to split
  * up each class in all kernels.
  */
 float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info,
-                          int32_t device, std::shared_ptr<DeviceAUCCache>* p_cache) {
+                          int32_t device, std::shared_ptr<DeviceAUCCache>* p_cache,
+                          size_t n_classes) {
+  dh::safe_cuda(cudaSetDevice(device));
   auto& cache = *p_cache;
   if (!cache) {
     cache.reset(new DeviceAUCCache);
@@ -213,8 +251,19 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
   auto weights = info.weights_.ConstDeviceSpan();
 
   size_t n_samples = labels.size();
-  size_t n_classes = predts.size() / labels.size();
-  CHECK_NE(n_classes, 0);
+
+  if (n_samples == 0) {
+    dh::TemporaryArray<float> resutls(n_classes * 4, 0.0f);
+    auto d_results = dh::ToSpan(resutls);
+    dh::LaunchN(device, n_classes * 4, [=]__device__(size_t i) {
+      d_results[i] = 0.0f;
+    });
+    auto local_area = d_results.subspan(0, n_classes);
+    auto fp = d_results.subspan(n_classes, n_classes);
+    auto tp = d_results.subspan(2 * n_classes, n_classes);
+    auto auc = d_results.subspan(3 * n_classes, n_classes);
+    return ScaleClasses(d_results, local_area, fp, tp, auc, cache, n_classes);
+  }
 
   /**
    * Create sorted index for each class
@@ -377,32 +426,7 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
         tp[c] = last.second;
         local_area[c] = last.first * last.second;
       });
-  if (rabit::IsDistributed()) {
-    cache->reducer->AllReduceSum(resutls.data().get(), resutls.data().get(),
-                                 resutls.size());
-  }
-  auto reduce_in = dh::MakeTransformIterator<thrust::pair<float, float>>(
-      thrust::make_counting_iterator(0), [=] __device__(size_t i) {
-        if (local_area[i] > 0) {
-          return thrust::make_pair(auc[i] / local_area[i] * tp[i], tp[i]);
-        }
-        return thrust::make_pair(std::numeric_limits<float>::quiet_NaN(), 0.0f);
-      });
-
-  float tp_sum;
-  float auc_sum;
-  thrust::tie(auc_sum, tp_sum) = thrust::reduce(
-      thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes,
-      thrust::make_pair(0.0f, 0.0f),
-      [=] __device__(auto const &l, auto const &r) {
-        return thrust::make_pair(l.first + r.first, l.second + r.second);
-      });
-  if (tp_sum != 0 && !std::isnan(auc_sum)) {
-    auc_sum /= tp_sum;
-  } else {
-    return std::numeric_limits<float>::quiet_NaN();
-  }
-  return auc_sum;
+  return ScaleClasses(d_results, local_area, fp, tp, auc, cache, n_classes);
 }
 
 namespace {
diff --git a/src/metric/auc.h b/src/metric/auc.h
index cb443f238..d549ac426 100644
--- a/src/metric/auc.h
+++ b/src/metric/auc.h
@@ -26,7 +26,8 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
              int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache);
 
 float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info,
-                          int32_t device, std::shared_ptr<DeviceAUCCache>* cache);
+                          int32_t device, std::shared_ptr<DeviceAUCCache>* cache,
+                          size_t n_classes);
 
 std::pair<float, uint32_t>
 GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py
index cfdd9db12..87618fbcf 100644
--- a/tests/python-gpu/test_gpu_with_dask.py
+++ b/tests/python-gpu/test_gpu_with_dask.py
@@ -277,7 +277,7 @@ class TestDistributedGPU:
         X = dask_cudf.from_dask_dataframe(dd.from_dask_array(X_))
         y = dask_cudf.from_dask_dataframe(dd.from_dask_array(y_))
         w = dask_cudf.from_dask_dataframe(dd.from_dask_array(w_))
-        run_dask_classifier(X, y, w, model, client, 10)
+        run_dask_classifier(X, y, w, model, "gpu_hist", client, 10)
 
     @pytest.mark.skipif(**tm.no_dask())
     @pytest.mark.skipif(**tm.no_dask_cuda())
diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py
index eae0f54b1..cb45c0dd9 100644
--- a/tests/python/test_with_dask.py
+++ b/tests/python/test_with_dask.py
@@ -317,6 +317,7 @@ def run_dask_classifier(
     y: xgb.dask._DaskCollection,
     w: xgb.dask._DaskCollection,
     model: str,
+    tree_method: Optional[str],
     client: "Client",
     n_classes,
 ) -> None:
@@ -324,11 +325,11 @@ def run_dask_classifier(
 
     if model == "boosting":
         classifier = xgb.dask.DaskXGBClassifier(
-            verbosity=1, n_estimators=2, eval_metric=metric
+            verbosity=1, n_estimators=2, eval_metric=metric, tree_method=tree_method
         )
     else:
         classifier = xgb.dask.DaskXGBRFClassifier(
-            verbosity=1, n_estimators=2, eval_metric=metric
+            verbosity=1, n_estimators=2, eval_metric=metric, tree_method=tree_method
         )
 
     assert classifier._estimator_type == "classifier"
@@ -397,12 +398,12 @@ def run_dask_classifier(
 def test_dask_classifier(model: str, client: "Client") -> None:
     X, y, w = generate_array(with_weights=True)
     y = (y * 10).astype(np.int32)
-    run_dask_classifier(X, y, w, model, client, 10)
+    run_dask_classifier(X, y, w, model, None, client, 10)
 
     y_bin = y.copy()
     y_bin[y > 5] = 1.0
     y_bin[y <= 5] = 0.0
-    run_dask_classifier(X, y_bin, w, model, client, 2)
+    run_dask_classifier(X, y_bin, w, model, None, client, 2)
 
 
 @pytest.mark.skipif(**tm.no_sklearn())
@@ -568,22 +569,26 @@ def run_empty_dmatrix_auc(client: "Client", tree_method: str, n_workers: int) -> None:
     # multiclass
     X_, y_ = make_classification(
         n_samples=n_samples,
-        n_classes=10,
+        n_classes=n_workers,
         n_informative=n_features,
         n_redundant=0,
         n_repeated=0
     )
+    for i in range(y_.shape[0]):
+        y_[i] = i % n_workers
     X = dd.from_array(X_, chunksize=10)
     y = dd.from_array(y_, chunksize=10)
 
     n_samples = n_workers - 1
     valid_X_, valid_y_ = make_classification(
         n_samples=n_samples,
-        n_classes=10,
+        n_classes=n_workers,
         n_informative=n_features,
         n_redundant=0,
         n_repeated=0
     )
+    for i in range(valid_y_.shape[0]):
+        valid_y_[i] = i % n_workers
     valid_X = dd.from_array(valid_X_, chunksize=n_samples)
     valid_y = dd.from_array(valid_y_, chunksize=n_samples)
 
@@ -594,9 +599,9 @@ def run_empty_dmatrix_auc(client: "Client", tree_method: str, n_workers: int) -> None:
 
 
 def test_empty_dmatrix_auc() -> None:
-    with LocalCluster(n_workers=2) as cluster:
+    with LocalCluster(n_workers=8) as cluster:
         with Client(cluster) as client:
-            run_empty_dmatrix_auc(client, "hist", 2)
+            run_empty_dmatrix_auc(client, "hist", 8)
 
 
 def run_auc(client: "Client", tree_method: str) -> None:
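Note on the reduction that the new `ScaleClasses` helper factors out: after `AllReduceSum` has summed the per-class `(local_area, fp, tp, auc)` buffers across workers, each class contributes its normalized one-vs-rest AUC weighted by its true-positive count, and a class with zero area contributes `(NaN, 0)`, which poisons the sum. The sketch below is an illustrative plain-Python mirror of that logic, not part of this patch (the standalone `scale_classes` helper is hypothetical):

```python
import math

def scale_classes(local_area, tp, auc):
    # Mirror of ScaleClasses after the cross-worker AllReduceSum: a
    # tp-weighted average of normalized per-class AUCs; any class with
    # zero area makes auc_sum NaN, so the overall result is NaN.
    pairs = [
        (a / area * t, t) if area > 0 else (math.nan, 0.0)
        for area, t, a in zip(local_area, tp, auc)
    ]
    auc_sum = sum(p[0] for p in pairs)
    tp_sum = sum(p[1] for p in pairs)
    if tp_sum != 0 and not math.isnan(auc_sum):
        return auc_sum / tp_sum
    return math.nan  # e.g. the dataset is empty on every worker

# Two populated classes: normalized AUCs 0.75 and 0.875 with equal tp weights.
assert scale_classes([4.0, 4.0], [2.0, 2.0], [3.0, 3.5]) == 0.8125
# All-zero buffers, the path taken by the new n_samples == 0 early return.
assert math.isnan(scale_classes([0.0], [0.0], [0.0]))
```

Allreducing the raw per-class counts before normalizing is what keeps workers in agreement: every worker computes the same global sums and therefore the same final metric, even when its local shard is empty.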