finished evaluator.cu

This commit is contained in:
amdsc21 2023-03-09 22:22:05 +01:00
parent f55243fda0
commit df42dd2c53
2 changed files with 88 additions and 0 deletions

View File

@ -7,7 +7,12 @@
#include <thrust/logical.h> // thrust::any_of
#include <thrust/sort.h> // thrust::stable_sort
#if defined(XGBOOST_USE_CUDA)
#include "../../common/device_helpers.cuh"
#elif defined(XGBOOST_USE_HIP)
#include "../../common/device_helpers.hip.h"
#endif
#include "../../common/hist_util.h" // common::HistogramCuts
#include "evaluate_splits.cuh"
#include "xgboost/data.h"
@ -30,6 +35,7 @@ void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts,
// This condition avoids calling the sort-based split function when the user wants
// one-hot-encoding-based splits.
// For some reason, any_of adds 1.5 minutes to compilation time for CUDA 11.x.
#if defined(XGBOOST_USE_CUDA)
need_sort_histogram_ =
thrust::any_of(thrust::cuda::par(alloc), beg, end, [=] XGBOOST_DEVICE(size_t i) {
auto idx = i - 1;
@ -40,14 +46,32 @@ void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts,
}
return false;
});
#elif defined(XGBOOST_USE_HIP)
need_sort_histogram_ =
thrust::any_of(thrust::hip::par(alloc), beg, end, [=] XGBOOST_DEVICE(size_t i) {
auto idx = i - 1;
if (common::IsCat(ft, idx)) {
auto n_bins = ptrs[i] - ptrs[idx];
bool use_sort = !common::UseOneHot(n_bins, to_onehot);
return use_sort;
}
return false;
});
#endif
node_categorical_storage_size_ =
common::CatBitField::ComputeStorageSize(cuts.MaxCategory() + 1);
CHECK_NE(node_categorical_storage_size_, 0);
split_cats_.resize(node_categorical_storage_size_);
h_split_cats_.resize(node_categorical_storage_size_);
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(
cudaMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST)));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(
hipMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST)));
#endif
cat_sorted_idx_.resize(cuts.cut_values_.Size() * 2); // evaluate 2 nodes at a time.
sort_input_.resize(cat_sorted_idx_.size());
@ -59,11 +83,20 @@ void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts,
auto d_fidxes = dh::ToSpan(feature_idx_);
auto it = thrust::make_counting_iterator(0ul);
auto values = cuts.cut_values_.ConstDeviceSpan();
#if defined(XGBOOST_USE_CUDA)
thrust::transform(thrust::cuda::par(alloc), it, it + feature_idx_.size(), feature_idx_.begin(),
[=] XGBOOST_DEVICE(size_t i) {
auto fidx = dh::SegmentId(ptrs, i);
return fidx;
});
#elif defined(XGBOOST_USE_HIP)
thrust::transform(thrust::hip::par(alloc), it, it + feature_idx_.size(), feature_idx_.begin(),
[=] XGBOOST_DEVICE(size_t i) {
auto fidx = dh::SegmentId(ptrs, i);
return fidx;
});
#endif
}
}
@ -77,6 +110,8 @@ common::Span<bst_feature_t const> GPUHistEvaluator::SortHistogram(
auto it = thrust::make_counting_iterator(0u);
auto d_feature_idx = dh::ToSpan(feature_idx_);
auto total_bins = shared_inputs.feature_values.size();
#if defined(XGBOOST_USE_CUDA)
thrust::transform(thrust::cuda::par(alloc), it, it + data.size(), dh::tbegin(data),
[=] XGBOOST_DEVICE(uint32_t i) {
auto const &input = d_inputs[i / total_bins];
@ -90,10 +125,27 @@ common::Span<bst_feature_t const> GPUHistEvaluator::SortHistogram(
}
return thrust::make_tuple(i, 0.0f);
});
#elif defined(XGBOOST_USE_HIP)
thrust::transform(thrust::hip::par(alloc), it, it + data.size(), dh::tbegin(data),
[=] XGBOOST_DEVICE(uint32_t i) {
auto const &input = d_inputs[i / total_bins];
auto j = i % total_bins;
auto fidx = d_feature_idx[j];
if (common::IsCat(shared_inputs.feature_types, fidx)) {
auto grad =
shared_inputs.rounding.ToFloatingPoint(input.gradient_histogram[j]);
auto lw = evaluator.CalcWeightCat(shared_inputs.param, grad);
return thrust::make_tuple(i, lw);
}
return thrust::make_tuple(i, 0.0f);
});
#endif
// Sort an array segmented according to
// - nodes
// - features within each node
// - gradients within each feature
#if defined(XGBOOST_USE_CUDA)
thrust::stable_sort_by_key(thrust::cuda::par(alloc), dh::tbegin(data), dh::tend(data),
dh::tbegin(sorted_idx),
[=] XGBOOST_DEVICE(SortPair const &l, SortPair const &r) {
@ -124,6 +176,38 @@ common::Span<bst_feature_t const> GPUHistEvaluator::SortHistogram(
}
return li < ri;
});
#elif defined(XGBOOST_USE_HIP)
thrust::stable_sort_by_key(thrust::hip::par(alloc), dh::tbegin(data), dh::tend(data),
dh::tbegin(sorted_idx),
[=] XGBOOST_DEVICE(SortPair const &l, SortPair const &r) {
auto li = thrust::get<0>(l);
auto ri = thrust::get<0>(r);
auto l_node = li / total_bins;
auto r_node = ri / total_bins;
if (l_node != r_node) {
return l_node < r_node; // not the same node
}
li = li % total_bins;
ri = ri % total_bins;
auto lfidx = d_feature_idx[li];
auto rfidx = d_feature_idx[ri];
if (lfidx != rfidx) {
return lfidx < rfidx; // not the same feature
}
if (common::IsCat(shared_inputs.feature_types, lfidx)) {
auto lw = thrust::get<1>(l);
auto rw = thrust::get<1>(r);
return lw < rw;
}
return li < ri;
});
#endif
return dh::ToSpan(cat_sorted_idx_);
}

View File

@ -0,0 +1,4 @@
#if defined(XGBOOST_USE_HIP)
#include "evaluator.cu"
#endif