finish rank_obj.cu

This commit is contained in:
amdsc21 2023-03-10 06:29:08 +01:00
parent 968a1db4c0
commit 41407850d5
3 changed files with 79 additions and 16 deletions

View File

@ -12,6 +12,6 @@ DMLC_REGISTRY_FILE_TAG(rank_obj);
} // namespace obj
} // namespace xgboost
#ifndef XGBOOST_USE_CUDA
#if !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)
#include "rank_obj.cu"
#endif // XGBOOST_USE_CUDA
#endif // !defined(XGBOOST_USE_CUDA) && !defined(XGBOOST_USE_HIP)

View File

@ -25,12 +25,23 @@
#include <cub/util_allocator.cuh>
#include "../common/device_helpers.cuh"
#elif defined(__HIP_PLATFORM_AMD__)
#include <thrust/sort.h>
#include <thrust/gather.h>
#include <thrust/iterator/discard_iterator.h>
#include <thrust/random/uniform_int_distribution.h>
#include <thrust/random/linear_congruential_engine.h>
#include <hipcub/util_allocator.hpp>
#include "../common/device_helpers.hip.h"
#endif
namespace xgboost {
namespace obj {
#if defined(XGBOOST_USE_CUDA) && !defined(GTEST_TEST)
#if (defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)) && !defined(GTEST_TEST)
DMLC_REGISTRY_FILE_TAG(rank_obj_gpu);
#endif // (defined(XGBOOST_USE_CUDA) || defined(XGBOOST_USE_HIP)) && !defined(GTEST_TEST)
@ -47,7 +58,7 @@ struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
}
};
#if defined(__CUDACC__)
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
// Helper functions
template <typename T>
@ -118,7 +129,7 @@ class PairwiseLambdaWeightComputer {
return "rank:pairwise";
}
#if defined(__CUDACC__)
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
PairwiseLambdaWeightComputer(const bst_float*,
const bst_float*,
const dh::SegmentSorter<float>&) {}
@ -137,7 +148,7 @@ class PairwiseLambdaWeightComputer {
#endif
};
#if defined(__CUDACC__)
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
class BaseLambdaWeightMultiplier {
public:
BaseLambdaWeightMultiplier(const dh::SegmentSorter<float> &segment_label_sorter,
@ -209,12 +220,12 @@ class IndexablePredictionSorter {
// beta version: NDCG lambda rank
class NDCGLambdaWeightComputer
#if defined(__CUDACC__)
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
: public IndexablePredictionSorter
#endif
{
public:
#if defined(__CUDACC__)
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
// This function object computes the item's DCG value
class ComputeItemDCG : public thrust::unary_function<uint32_t, float> {
public:
@ -281,6 +292,7 @@ class NDCGLambdaWeightComputer
dh::XGBCachingDeviceAllocator<char> alloc;
// Compute each element's DCG value and reduce them across groups concurrently.
#if defined(XGBOOST_USE_CUDA)
auto end_range =
thrust::reduce_by_key(thrust::cuda::par(alloc),
dh::tcbegin(group_segments), dh::tcend(group_segments),
@ -293,6 +305,20 @@ class NDCGLambdaWeightComputer
group_segments)),
thrust::make_discard_iterator(), // We don't care for the group indices
dgroup_dcg_.begin()); // Sum of the item's DCG values in the group
#elif defined(XGBOOST_USE_HIP)
auto end_range =
thrust::reduce_by_key(thrust::hip::par(alloc),
dh::tcbegin(group_segments), dh::tcend(group_segments),
thrust::make_transform_iterator(
// The indices need not be sequential within a group, as we care only
// about the sum of items DCG values within a group
dh::tcbegin(segment_label_sorter.GetOriginalPositionsSpan()),
ComputeItemDCG(segment_label_sorter.GetItemsSpan(),
segment_label_sorter.GetGroupsSpan(),
group_segments)),
thrust::make_discard_iterator(), // We don't care for the group indices
dgroup_dcg_.begin()); // Sum of the item's DCG values in the group
#endif
CHECK_EQ(static_cast<unsigned>(end_range.second - dgroup_dcg_.begin()), dgroup_dcg_.size());
}
@ -368,7 +394,7 @@ class NDCGLambdaWeightComputer
return delta;
}
#if defined(__CUDACC__)
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
dh::caching_device_vector<float> dgroup_dcg_;
// This computes the adjustment to the weight
const NDCGLambdaWeightMultiplier weight_multiplier_;
@ -376,7 +402,7 @@ class NDCGLambdaWeightComputer
};
class MAPLambdaWeightComputer
#if defined(__CUDACC__)
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
: public IndexablePredictionSorter
#endif
{
@ -417,7 +443,7 @@ class MAPLambdaWeightComputer
private:
template <typename T>
XGBOOST_DEVICE inline static void Swap(T &v0, T &v1) {
#if defined(__CUDACC__)
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
thrust::swap(v0, v1);
#else
std::swap(v0, v1);
@ -504,7 +530,7 @@ class MAPLambdaWeightComputer
}
}
#if defined(__CUDACC__)
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
MAPLambdaWeightComputer(const bst_float *dpreds,
const bst_float *dlabels,
const dh::SegmentSorter<float> &segment_label_sorter)
@ -545,10 +571,17 @@ class MAPLambdaWeightComputer
// This is required for computing the accumulated precisions
const auto &group_segments = segment_label_sorter.GetGroupSegmentsSpan();
// Data segmented into different groups...
#if defined(XGBOOST_USE_CUDA)
thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
dh::tcbegin(group_segments), dh::tcend(group_segments),
dhits.begin(), // Input value
dhits.begin()); // In-place scan
#elif defined(XGBOOST_USE_HIP)
thrust::inclusive_scan_by_key(thrust::hip::par(alloc),
dh::tcbegin(group_segments), dh::tcend(group_segments),
dhits.begin(), // Input value
dhits.begin()); // In-place scan
#endif
// Compute accumulated precisions for each item, assuming positive and
// negative instances are missing.
@ -574,10 +607,17 @@ class MAPLambdaWeightComputer
// Lastly, compute the accumulated precisions for all the items segmented by groups.
// The precisions are accumulated within each group
#if defined(XGBOOST_USE_CUDA)
thrust::inclusive_scan_by_key(thrust::cuda::par(alloc),
dh::tcbegin(group_segments), dh::tcend(group_segments),
this->dmap_stats_.begin(), // Input map stats
this->dmap_stats_.begin()); // In-place scan and output here
#elif defined(XGBOOST_USE_HIP)
thrust::inclusive_scan_by_key(thrust::hip::par(alloc),
dh::tcbegin(group_segments), dh::tcend(group_segments),
this->dmap_stats_.begin(), // Input map stats
this->dmap_stats_.begin()); // In-place scan and output here
#endif
}
inline const common::Span<const MAPStats> GetMapStatsSpan() const {
@ -625,7 +665,7 @@ class MAPLambdaWeightComputer
#endif
};
#if defined(__CUDACC__)
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
class SortedLabelList : dh::SegmentSorter<float> {
private:
const LambdaRankParam &param_; // Objective configuration
@ -670,7 +710,13 @@ class SortedLabelList : dh::SegmentSorter<float> {
auto wmultiplier = weight_computer.GetWeightMultiplier();
int device_id = -1;
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaGetDevice(&device_id));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipGetDevice(&device_id));
#endif
// For each instance in the group, compute the gradient pair concurrently
dh::LaunchN(niter, nullptr, [=] __device__(uint32_t idx) {
// First, determine the group 'idx' belongs to
@ -723,7 +769,12 @@ class SortedLabelList : dh::SegmentSorter<float> {
bst_float h = thrust::max(p * (1.0f - p), eps);
// Rescale each gradient and hessian so that the group has a weighted constant
#if defined(XGBOOST_USE_CUDA)
float scale = __frcp_ru(niter / total_items);
#elif defined(XGBOOST_USE_HIP)
float scale = __frcp_rn(niter / total_items);
#endif
if (fix_list_weight != 0.0f) {
scale *= fix_list_weight / total_group_items;
}
@ -741,7 +792,11 @@ class SortedLabelList : dh::SegmentSorter<float> {
});
// Wait until the computations done by the kernel are complete
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaStreamSynchronize(nullptr));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipStreamSynchronize(nullptr));
#endif
}
};
#endif
@ -768,7 +823,7 @@ class LambdaRankObj : public ObjFunction {
<< "labels size: " << info.labels.Size() << ", "
<< "group pointer back: " << (gptr.size() == 0 ? 0 : gptr.back());
#if defined(__CUDACC__)
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
// Check if we have a GPU assignment; else, revert back to CPU
auto device = ctx_->gpu_id;
if (device >= 0) {
@ -777,7 +832,7 @@ class LambdaRankObj : public ObjFunction {
// Revert back to CPU
#endif
ComputeGradientsOnCPU(preds, info, iter, out_gpair, gptr);
#if defined(__CUDACC__)
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
}
#endif
}
@ -898,7 +953,7 @@ class LambdaRankObj : public ObjFunction {
exc.Rethrow();
}
#if defined(__CUDACC__)
#if defined(__CUDACC__) || defined(__HIP_PLATFORM_AMD__)
void ComputeGradientsOnGPU(const HostDeviceVector<bst_float>& preds,
const MetaInfo& info,
int iter,
@ -907,7 +962,11 @@ class LambdaRankObj : public ObjFunction {
LOG(DEBUG) << "Computing " << LambdaWeightComputerT::Name() << " gradients on GPU.";
auto device = ctx_->gpu_id;
#if defined(XGBOOST_USE_CUDA)
dh::safe_cuda(cudaSetDevice(device));
#elif defined(XGBOOST_USE_HIP)
dh::safe_cuda(hipSetDevice(device));
#endif
bst_float weight_normalization_factor = ComputeWeightNormalizationFactor(info, gptr);

View File

@ -0,0 +1,4 @@
#if defined(XGBOOST_USE_HIP)
#include "rank_obj.cu"
#endif