finished histogram.cu
This commit is contained in:
@@ -7,7 +7,12 @@
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
#include "../../common/device_helpers.cuh"
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
#include "../../common/device_helpers.hip.h"
|
||||
#endif
|
||||
|
||||
#include "xgboost/base.h"
|
||||
#include "xgboost/context.h"
|
||||
#include "xgboost/task.h"
|
||||
@@ -140,13 +145,25 @@ void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
|
||||
});
|
||||
size_t temp_bytes = 0;
|
||||
if (tmp->empty()) {
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator, discard_write_iterator,
|
||||
IndexFlagOp(), total_rows, stream);
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
rocprim::inclusive_scan(nullptr, temp_bytes, input_iterator, discard_write_iterator,
|
||||
total_rows, IndexFlagOp(), stream);
|
||||
#endif
|
||||
|
||||
tmp->resize(temp_bytes);
|
||||
}
|
||||
temp_bytes = tmp->size();
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
cub::DeviceScan::InclusiveScan(tmp->data().get(), temp_bytes, input_iterator,
|
||||
discard_write_iterator, IndexFlagOp(), total_rows, stream);
|
||||
#elif defined(XGBOOST_USE_HIP)
|
||||
rocprim::inclusive_scan(tmp->data().get(), temp_bytes, input_iterator, discard_write_iterator,
|
||||
total_rows, IndexFlagOp(), stream);
|
||||
#endif
|
||||
|
||||
constexpr int kBlockSize = 256;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user