xgboost/src/tree/gpu_hist/row_partitioner.cu
2023-10-17 12:42:37 -07:00

47 lines
1.3 KiB
Plaintext

/*!
* Copyright 2017-2022 XGBoost contributors
*/
#include <thrust/iterator/discard_iterator.h>
#include <thrust/iterator/transform_output_iterator.h>
#include <thrust/sequence.h>
#include <vector>
#include "../../common/device_helpers.cuh"
#include "row_partitioner.cuh"
namespace xgboost {
namespace tree {
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
: device_idx_(device_idx), ridx_(num_rows), ridx_tmp_(num_rows) {
dh::safe_cuda(cudaSetDevice(device_idx_));
ridx_segments_.emplace_back(NodePositionInfo{Segment(0, num_rows)});
thrust::sequence(thrust::device, ridx_.data(), ridx_.data() + ridx_.size());
}
RowPartitioner::~RowPartitioner() {
dh::safe_cuda(cudaSetDevice(device_idx_));
}
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(bst_node_t nidx) {
auto segment = ridx_segments_.at(nidx).segment;
return dh::ToSpan(ridx_).subspan(segment.begin, segment.Size());
}
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows() {
return dh::ToSpan(ridx_);
}
std::vector<RowPartitioner::RowIndexT> RowPartitioner::GetRowsHost(bst_node_t nidx) {
auto span = GetRows(nidx);
std::vector<RowIndexT> rows(span.size());
dh::CopyDeviceSpanToVector(&rows, span);
return rows;
}
}; // namespace tree
}; // namespace xgboost