Convert labels into tensor. (#7456)

* Add a new ctor to tensor for `initilizer_list`.
* Change labels from host device vector to tensor.
* Rename the field from `labels_` to `labels` since it's a public member.
This commit is contained in:
Jiaming Yuan
2021-12-17 00:58:35 +08:00
committed by GitHub
parent 6f8a4633b7
commit 5b1161bb64
35 changed files with 319 additions and 258 deletions

View File

@@ -89,12 +89,12 @@ std::tuple<double, double, double>
GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
int32_t device, common::Span<size_t const> d_sorted_idx,
Fn area_fn, std::shared_ptr<DeviceAUCCache> cache) {
auto labels = info.labels_.ConstDeviceSpan();
auto labels = info.labels.View(device);
auto weights = info.weights_.ConstDeviceSpan();
dh::safe_cuda(cudaSetDevice(device));
CHECK(!labels.empty());
CHECK_EQ(labels.size(), predts.size());
CHECK_NE(labels.Size(), 0);
CHECK_EQ(labels.Size(), predts.size());
/**
* Linear scan
@@ -103,7 +103,7 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
auto get_fp_tp = [=]XGBOOST_DEVICE(size_t i) {
size_t idx = d_sorted_idx[i];
float label = labels[idx];
float label = labels(idx);
float w = get_weight[d_sorted_idx[i]];
float fp = (1.0 - label) * w;
@@ -332,10 +332,10 @@ double GPUMultiClassAUCOVR(common::Span<float const> predts,
// Index is sorted within class.
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
auto labels = info.labels_.ConstDeviceSpan();
auto labels = info.labels.View(device);
auto weights = info.weights_.ConstDeviceSpan();
size_t n_samples = labels.size();
size_t n_samples = labels.Shape(0);
if (n_samples == 0) {
dh::TemporaryArray<double> resutls(n_classes * 4, 0.0f);
@@ -360,7 +360,7 @@ double GPUMultiClassAUCOVR(common::Span<float const> predts,
size_t class_id = i / n_samples;
// labels is a vector of size n_samples.
float label = labels[idx % n_samples] == class_id;
float label = labels(idx % n_samples) == class_id;
float w = get_weight[d_sorted_idx[i] % n_samples];
float fp = (1.0 - label) * w;
@@ -528,10 +528,10 @@ GPURankingAUC(common::Span<float const> predts, MetaInfo const &info,
/**
* Sort the labels
*/
auto d_labels = info.labels_.ConstDeviceSpan();
auto d_labels = info.labels.View(device);
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::SegmentedArgSort<false>(d_labels, d_group_ptr, d_sorted_idx);
dh::SegmentedArgSort<false>(d_labels.Values(), d_group_ptr, d_sorted_idx);
auto d_weights = info.weights_.ConstDeviceSpan();
@@ -631,19 +631,19 @@ GPUBinaryPRAUC(common::Span<float const> predts, MetaInfo const &info,
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
dh::ArgSort<false>(predts, d_sorted_idx);
auto labels = info.labels_.ConstDeviceSpan();
auto labels = info.labels.View(device);
auto d_weights = info.weights_.ConstDeviceSpan();
auto get_weight = OptionalWeights{d_weights};
auto it = dh::MakeTransformIterator<Pair>(
thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) {
auto w = get_weight[d_sorted_idx[i]];
return thrust::make_pair(labels[d_sorted_idx[i]] * w,
(1.0f - labels[d_sorted_idx[i]]) * w);
return thrust::make_pair(labels(d_sorted_idx[i]) * w,
(1.0f - labels(d_sorted_idx[i])) * w);
});
dh::XGBCachingDeviceAllocator<char> alloc;
double total_pos, total_neg;
thrust::tie(total_pos, total_neg) =
thrust::reduce(thrust::cuda::par(alloc), it, it + labels.size(),
thrust::reduce(thrust::cuda::par(alloc), it, it + labels.Size(),
Pair{0.0, 0.0}, PairPlus<double, double>{});
if (total_pos <= 0.0 || total_neg <= 0.0) {
@@ -679,7 +679,7 @@ double GPUMultiClassPRAUC(common::Span<float const> predts,
/**
* Get total positive/negative
*/
auto labels = info.labels_.ConstDeviceSpan();
auto labels = info.labels.View(device);
auto n_samples = info.num_row_;
dh::caching_device_vector<Pair> totals(n_classes);
auto key_it =
@@ -693,7 +693,7 @@ double GPUMultiClassPRAUC(common::Span<float const> predts,
auto idx = d_sorted_idx[i] % n_samples;
auto w = get_weight[idx];
auto class_id = i / n_samples;
auto y = labels[idx] == class_id;
auto y = labels(idx) == class_id;
return thrust::make_pair(y * w, (1.0f - y) * w);
});
dh::XGBCachingDeviceAllocator<char> alloc;
@@ -726,7 +726,7 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
*/
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
auto labels = info.labels_.ConstDeviceSpan();
auto labels = info.labels.View(device);
auto weights = info.weights_.ConstDeviceSpan();
uint32_t n_groups = static_cast<uint32_t>(info.group_ptr_.size() - 1);
@@ -734,7 +734,7 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
/**
* Linear scan
*/
size_t n_samples = labels.size();
size_t n_samples = labels.Shape(0);
dh::caching_device_vector<double> d_auc(n_groups, 0);
auto get_weight = OptionalWeights{weights};
auto d_fptp = dh::ToSpan(cache->fptp);
@@ -742,7 +742,7 @@ GPURankingPRAUCImpl(common::Span<float const> predts, MetaInfo const &info,
size_t idx = d_sorted_idx[i];
size_t group_id = dh::SegmentId(d_group_ptr, idx);
float label = labels[idx];
float label = labels(idx);
float w = get_weight[group_id];
float fp = (1.0 - label) * w;
@@ -860,9 +860,9 @@ GPURankingPRAUC(common::Span<float const> predts, MetaInfo const &info,
dh::SegmentedArgSort<false>(predts, d_group_ptr, d_sorted_idx);
dh::XGBDeviceAllocator<char> alloc;
auto labels = info.labels_.ConstDeviceSpan();
if (thrust::any_of(thrust::cuda::par(alloc), dh::tbegin(labels),
dh::tend(labels), PRAUCLabelInvalid{})) {
auto labels = info.labels.View(device);
if (thrust::any_of(thrust::cuda::par(alloc), dh::tbegin(labels.Values()),
dh::tend(labels.Values()), PRAUCLabelInvalid{})) {
InvalidLabels();
}
/**
@@ -881,7 +881,7 @@ GPURankingPRAUC(common::Span<float const> predts, MetaInfo const &info,
auto g = dh::SegmentId(d_group_ptr, i);
w = d_weights[g];
}
auto y = labels[i];
auto y = labels(i);
return thrust::make_pair(y * w, (1.0 - y) * w);
});
thrust::reduce_by_key(thrust::cuda::par(alloc), key_it,
@@ -899,7 +899,7 @@ GPURankingPRAUC(common::Span<float const> predts, MetaInfo const &info,
return detail::CalcDeltaPRAUC(fp_prev, fp, tp_prev, tp,
d_totals[group_id].first);
};
return GPURankingPRAUCImpl(predts, info, d_group_ptr, n_groups, cache, fn);
return GPURankingPRAUCImpl(predts, info, d_group_ptr, device, cache, fn);
}
} // namespace metric
} // namespace xgboost