finished histogram.cu

This commit is contained in:
amdsc21
2023-03-09 21:28:37 +01:00
parent f67e7de7ef
commit 0ed5d3c849
6 changed files with 69 additions and 2 deletions

View File

@@ -7,7 +7,12 @@
#include <limits>
#include <vector>
#if defined(XGBOOST_USE_CUDA)
#include "../../common/device_helpers.cuh"
#elif defined(XGBOOST_USE_HIP)
#include "../../common/device_helpers.hip.h"
#endif
#include "xgboost/base.h"
#include "xgboost/context.h"
#include "xgboost/task.h"
@@ -140,13 +145,25 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
});
size_t temp_bytes = 0;
if (tmp->empty()) {
#if defined(XGBOOST_USE_CUDA)
cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator, discard_write_iterator,
IndexFlagOp(), total_rows, stream);
#elif defined(XGBOOST_USE_HIP)
rocprim::inclusive_scan(nullptr, temp_bytes, input_iterator, discard_write_iterator,
total_rows, IndexFlagOp(), stream);
#endif
tmp->resize(temp_bytes);
}
temp_bytes = tmp->size();
#if defined(XGBOOST_USE_CUDA)
cub::DeviceScan::InclusiveScan(tmp->data().get(), temp_bytes, input_iterator,
discard_write_iterator, IndexFlagOp(), total_rows, stream);
#elif defined(XGBOOST_USE_HIP)
rocprim::inclusive_scan(tmp->data().get(), temp_bytes, input_iterator, discard_write_iterator,
total_rows, IndexFlagOp(), stream);
#endif
constexpr int kBlockSize = 256;