parent
651c4ac03b
commit
8147d78b12
@ -87,8 +87,7 @@ std::tuple<float, float, float> BinaryAUC(std::vector<float> const &predts,
|
||||
* - Kleiman, Ross and Page, David. $AUC_{\mu}$: A Performance Metric for Multi-Class
|
||||
* Machine Learning Models
|
||||
*/
|
||||
float MultiClassOVR(std::vector<float> const& predts, MetaInfo const& info) {
|
||||
auto n_classes = predts.size() / info.labels_.Size();
|
||||
float MultiClassOVR(std::vector<float> const& predts, MetaInfo const& info, size_t n_classes) {
|
||||
CHECK_NE(n_classes, 0);
|
||||
auto const& labels = info.labels_.ConstHostVector();
|
||||
|
||||
@ -230,6 +229,10 @@ class EvalAUC : public Metric {
|
||||
info.labels_.SetDevice(tparam_->gpu_id);
|
||||
info.weights_.SetDevice(tparam_->gpu_id);
|
||||
}
|
||||
// We use the global size to handle empty dataset.
|
||||
std::array<size_t, 2> meta{info.labels_.Size(), preds.Size()};
|
||||
rabit::Allreduce<rabit::op::Max>(meta.data(), meta.size());
|
||||
|
||||
if (!info.group_ptr_.empty()) {
|
||||
/**
|
||||
* learning to rank
|
||||
@ -261,16 +264,17 @@ class EvalAUC : public Metric {
|
||||
CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups
|
||||
<< ", valid groups: " << valid_groups;
|
||||
}
|
||||
} else if (info.labels_.Size() != preds.Size() &&
|
||||
preds.Size() % info.labels_.Size() == 0) {
|
||||
} else if (meta[0] != meta[1] && meta[1] % meta[0] == 0) {
|
||||
/**
|
||||
* multi class
|
||||
*/
|
||||
size_t n_classes = meta[1] / meta[0];
|
||||
CHECK_NE(n_classes, 0);
|
||||
if (tparam_->gpu_id == GenericParameter::kCpuId) {
|
||||
auc = MultiClassOVR(preds.ConstHostVector(), info);
|
||||
auc = MultiClassOVR(preds.ConstHostVector(), info, n_classes);
|
||||
} else {
|
||||
auc = GPUMultiClassAUCOVR(preds.ConstDeviceSpan(), info, tparam_->gpu_id,
|
||||
&this->d_cache_);
|
||||
&this->d_cache_, n_classes);
|
||||
}
|
||||
} else {
|
||||
/**
|
||||
@ -323,7 +327,8 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
|
||||
}
|
||||
|
||||
float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info,
|
||||
int32_t device, std::shared_ptr<DeviceAUCCache>* cache) {
|
||||
int32_t device, std::shared_ptr<DeviceAUCCache>* cache,
|
||||
size_t n_classes) {
|
||||
common::AssertGPUSupport();
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -61,10 +61,12 @@ struct DeviceAUCCache {
|
||||
neg_pos.resize(sorted_idx.size());
|
||||
if (is_multi) {
|
||||
predts_t.resize(sorted_idx.size());
|
||||
reducer.reset(new dh::AllReducer);
|
||||
reducer->Init(rabit::GetRank());
|
||||
}
|
||||
}
|
||||
if (is_multi && !reducer) {
|
||||
reducer.reset(new dh::AllReducer);
|
||||
reducer->Init(device);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -197,12 +199,48 @@ XGBOOST_DEVICE size_t LastOf(size_t group, common::Span<Idx> indptr) {
|
||||
return indptr[group + 1] - 1;
|
||||
}
|
||||
|
||||
|
||||
float ScaleClasses(common::Span<float> results, common::Span<float> local_area,
|
||||
common::Span<float> fp, common::Span<float> tp,
|
||||
common::Span<float> auc, std::shared_ptr<DeviceAUCCache> cache,
|
||||
size_t n_classes) {
|
||||
dh::XGBDeviceAllocator<char> alloc;
|
||||
if (rabit::IsDistributed()) {
|
||||
CHECK_EQ(dh::CudaGetPointerDevice(results.data()), dh::CurrentDevice());
|
||||
cache->reducer->AllReduceSum(results.data(), results.data(), results.size());
|
||||
}
|
||||
auto reduce_in = dh::MakeTransformIterator<thrust::pair<float, float>>(
|
||||
thrust::make_counting_iterator(0), [=] __device__(size_t i) {
|
||||
if (local_area[i] > 0) {
|
||||
return thrust::make_pair(auc[i] / local_area[i] * tp[i], tp[i]);
|
||||
}
|
||||
return thrust::make_pair(std::numeric_limits<float>::quiet_NaN(), 0.0f);
|
||||
});
|
||||
|
||||
float tp_sum;
|
||||
float auc_sum;
|
||||
thrust::tie(auc_sum, tp_sum) = thrust::reduce(
|
||||
thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes,
|
||||
thrust::make_pair(0.0f, 0.0f),
|
||||
[=] __device__(auto const &l, auto const &r) {
|
||||
return thrust::make_pair(l.first + r.first, l.second + r.second);
|
||||
});
|
||||
if (tp_sum != 0 && !std::isnan(auc_sum)) {
|
||||
auc_sum /= tp_sum;
|
||||
} else {
|
||||
return std::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
return auc_sum;
|
||||
}
|
||||
|
||||
/**
|
||||
* MultiClass implementation is similar to binary classification, except we need to split
|
||||
* up each class in all kernels.
|
||||
*/
|
||||
float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info,
|
||||
int32_t device, std::shared_ptr<DeviceAUCCache>* p_cache) {
|
||||
int32_t device, std::shared_ptr<DeviceAUCCache>* p_cache,
|
||||
size_t n_classes) {
|
||||
dh::safe_cuda(cudaSetDevice(device));
|
||||
auto& cache = *p_cache;
|
||||
if (!cache) {
|
||||
cache.reset(new DeviceAUCCache);
|
||||
@ -213,8 +251,19 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
|
||||
auto weights = info.weights_.ConstDeviceSpan();
|
||||
|
||||
size_t n_samples = labels.size();
|
||||
size_t n_classes = predts.size() / labels.size();
|
||||
CHECK_NE(n_classes, 0);
|
||||
|
||||
if (n_samples == 0) {
|
||||
dh::TemporaryArray<float> resutls(n_classes * 4, 0.0f);
|
||||
auto d_results = dh::ToSpan(resutls);
|
||||
dh::LaunchN(device, n_classes * 4, [=]__device__(size_t i) {
|
||||
d_results[i] = 0.0f;
|
||||
});
|
||||
auto local_area = d_results.subspan(0, n_classes);
|
||||
auto fp = d_results.subspan(n_classes, n_classes);
|
||||
auto tp = d_results.subspan(2 * n_classes, n_classes);
|
||||
auto auc = d_results.subspan(3 * n_classes, n_classes);
|
||||
return ScaleClasses(d_results, local_area, fp, tp, auc, cache, n_classes);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create sorted index for each class
|
||||
@ -377,32 +426,7 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
|
||||
tp[c] = last.second;
|
||||
local_area[c] = last.first * last.second;
|
||||
});
|
||||
if (rabit::IsDistributed()) {
|
||||
cache->reducer->AllReduceSum(resutls.data().get(), resutls.data().get(),
|
||||
resutls.size());
|
||||
}
|
||||
auto reduce_in = dh::MakeTransformIterator<thrust::pair<float, float>>(
|
||||
thrust::make_counting_iterator(0), [=] __device__(size_t i) {
|
||||
if (local_area[i] > 0) {
|
||||
return thrust::make_pair(auc[i] / local_area[i] * tp[i], tp[i]);
|
||||
}
|
||||
return thrust::make_pair(std::numeric_limits<float>::quiet_NaN(), 0.0f);
|
||||
});
|
||||
|
||||
float tp_sum;
|
||||
float auc_sum;
|
||||
thrust::tie(auc_sum, tp_sum) = thrust::reduce(
|
||||
thrust::cuda::par(alloc), reduce_in, reduce_in + n_classes,
|
||||
thrust::make_pair(0.0f, 0.0f),
|
||||
[=] __device__(auto const &l, auto const &r) {
|
||||
return thrust::make_pair(l.first + r.first, l.second + r.second);
|
||||
});
|
||||
if (tp_sum != 0 && !std::isnan(auc_sum)) {
|
||||
auc_sum /= tp_sum;
|
||||
} else {
|
||||
return std::numeric_limits<float>::quiet_NaN();
|
||||
}
|
||||
return auc_sum;
|
||||
return ScaleClasses(d_results, local_area, fp, tp, auc, cache, n_classes);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
@ -26,7 +26,8 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
|
||||
int32_t device, std::shared_ptr<DeviceAUCCache> *p_cache);
|
||||
|
||||
float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info,
|
||||
int32_t device, std::shared_ptr<DeviceAUCCache>* cache);
|
||||
int32_t device, std::shared_ptr<DeviceAUCCache>* cache,
|
||||
size_t n_classes);
|
||||
|
||||
std::pair<float, uint32_t>
|
||||
GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
|
||||
|
||||
@ -277,7 +277,7 @@ class TestDistributedGPU:
|
||||
X = dask_cudf.from_dask_dataframe(dd.from_dask_array(X_))
|
||||
y = dask_cudf.from_dask_dataframe(dd.from_dask_array(y_))
|
||||
w = dask_cudf.from_dask_dataframe(dd.from_dask_array(w_))
|
||||
run_dask_classifier(X, y, w, model, client, 10)
|
||||
run_dask_classifier(X, y, w, model, "gpu_hist", client, 10)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask())
|
||||
@pytest.mark.skipif(**tm.no_dask_cuda())
|
||||
|
||||
@ -317,6 +317,7 @@ def run_dask_classifier(
|
||||
y: xgb.dask._DaskCollection,
|
||||
w: xgb.dask._DaskCollection,
|
||||
model: str,
|
||||
tree_method: Optional[str],
|
||||
client: "Client",
|
||||
n_classes,
|
||||
) -> None:
|
||||
@ -324,11 +325,11 @@ def run_dask_classifier(
|
||||
|
||||
if model == "boosting":
|
||||
classifier = xgb.dask.DaskXGBClassifier(
|
||||
verbosity=1, n_estimators=2, eval_metric=metric
|
||||
verbosity=1, n_estimators=2, eval_metric=metric, tree_method=tree_method
|
||||
)
|
||||
else:
|
||||
classifier = xgb.dask.DaskXGBRFClassifier(
|
||||
verbosity=1, n_estimators=2, eval_metric=metric
|
||||
verbosity=1, n_estimators=2, eval_metric=metric, tree_method=tree_method
|
||||
)
|
||||
|
||||
assert classifier._estimator_type == "classifier"
|
||||
@ -397,12 +398,12 @@ def run_dask_classifier(
|
||||
def test_dask_classifier(model: str, client: "Client") -> None:
|
||||
X, y, w = generate_array(with_weights=True)
|
||||
y = (y * 10).astype(np.int32)
|
||||
run_dask_classifier(X, y, w, model, client, 10)
|
||||
run_dask_classifier(X, y, w, model, None, client, 10)
|
||||
|
||||
y_bin = y.copy()
|
||||
y_bin[y > 5] = 1.0
|
||||
y_bin[y <= 5] = 0.0
|
||||
run_dask_classifier(X, y_bin, w, model, client, 2)
|
||||
run_dask_classifier(X, y_bin, w, model, None, client, 2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
@ -568,22 +569,26 @@ def run_empty_dmatrix_auc(client: "Client", tree_method: str, n_workers: int) ->
|
||||
# multiclass
|
||||
X_, y_ = make_classification(
|
||||
n_samples=n_samples,
|
||||
n_classes=10,
|
||||
n_classes=n_workers,
|
||||
n_informative=n_features,
|
||||
n_redundant=0,
|
||||
n_repeated=0
|
||||
)
|
||||
for i in range(y_.shape[0]):
|
||||
y_[i] = i % n_workers
|
||||
X = dd.from_array(X_, chunksize=10)
|
||||
y = dd.from_array(y_, chunksize=10)
|
||||
|
||||
n_samples = n_workers - 1
|
||||
valid_X_, valid_y_ = make_classification(
|
||||
n_samples=n_samples,
|
||||
n_classes=10,
|
||||
n_classes=n_workers,
|
||||
n_informative=n_features,
|
||||
n_redundant=0,
|
||||
n_repeated=0
|
||||
)
|
||||
for i in range(valid_y_.shape[0]):
|
||||
valid_y_[i] = i % n_workers
|
||||
valid_X = dd.from_array(valid_X_, chunksize=n_samples)
|
||||
valid_y = dd.from_array(valid_y_, chunksize=n_samples)
|
||||
|
||||
@ -594,9 +599,9 @@ def run_empty_dmatrix_auc(client: "Client", tree_method: str, n_workers: int) ->
|
||||
|
||||
|
||||
def test_empty_dmatrix_auc() -> None:
|
||||
with LocalCluster(n_workers=2) as cluster:
|
||||
with LocalCluster(n_workers=8) as cluster:
|
||||
with Client(cluster) as client:
|
||||
run_empty_dmatrix_auc(client, "hist", 2)
|
||||
run_empty_dmatrix_auc(client, "hist", 8)
|
||||
|
||||
|
||||
def run_auc(client: "Client", tree_method: str) -> None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user