Upgrade clang-tidy on CI. (#5469)

* Correct all clang-tidy errors.
* Upgrade clang-tidy to 10 on CI.

Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
Author: Jiaming Yuan
Date: 2020-04-05 04:42:29 +08:00
Committed by: GitHub
Parent: 30e94ddd04
Commit: 0012f2ef93
107 changed files with 932 additions and 903 deletions
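
Most of the 107-file delta is this kind of mechanical rename: clang-tidy 10's naming rules (likely readability-identifier-naming with a trailing-underscore suffix for private members; the .clang-tidy configuration itself is not part of the hunks shown here) turn device_idx into device_idx_, ridx into ridx_, and so on. A minimal hypothetical sketch of the convention, not code from this commit:

// Hypothetical illustration only: private data members carry a trailing underscore.
class RowPartitionerStub {
 public:
  explicit RowPartitionerStub(int device_idx) : device_idx_{device_idx} {}
  int DeviceIdx() const { return device_idx_; }

 private:
  int device_idx_;  // old name: device_idx; the underscore marks a private member
};

int main() {
  RowPartitionerStub part{0};
  return part.DeviceIdx();
}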

View File

@@ -153,7 +153,7 @@ ExternalMemoryNoSampling::ExternalMemoryNoSampling(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param)
: batch_param_(batch_param),
page_(new EllpackPageImpl(batch_param.gpu_id, page->cuts_, page->is_dense,
page_(new EllpackPageImpl(batch_param.gpu_id, page->Cuts(), page->is_dense,
page->row_stride, n_rows)) {}
GradientBasedSample ExternalMemoryNoSampling::Sample(common::Span<GradientPair> gpair,
@@ -201,7 +201,6 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientP
// Count the sampled rows.
size_t sample_rows = thrust::count_if(dh::tbegin(gpair), dh::tend(gpair), IsNonZero());
size_t n_rows = dmat->Info().num_row_;
// Compact gradient pairs.
gpair_.resize(sample_rows);
@@ -219,7 +218,7 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientP
// Create a new ELLPACK page with empty rows.
page_.reset(); // Release the device memory first before reallocating
page_.reset(new EllpackPageImpl(
batch_param_.gpu_id, original_page_->cuts_, original_page_->is_dense,
batch_param_.gpu_id, original_page_->Cuts(), original_page_->is_dense,
original_page_->row_stride, sample_rows));
// Compact the ELLPACK pages into the single sample page.
@@ -299,7 +298,7 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span<Gra
// Create a new ELLPACK page with empty rows.
page_.reset(); // Release the device memory first before reallocating
page_.reset(new EllpackPageImpl(batch_param_.gpu_id, original_page_->cuts_,
page_.reset(new EllpackPageImpl(batch_param_.gpu_id, original_page_->Cuts(),
original_page_->is_dense,
original_page_->row_stride, sample_rows));
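
The sampler hunks above only swap page->cuts_ for page->Cuts(); presumably the cuts member of EllpackPageImpl was moved behind an accessor elsewhere in this commit (that part of the diff is not shown). A hypothetical, standalone sketch of the accessor pattern implied by the change:

// Illustrative only; not the real EllpackPageImpl or common::HistogramCuts.
#include <cstddef>
#include <utility>
#include <vector>

struct HistogramCutsStub {
  std::vector<float> values;  // stand-in for the histogram cut points
};

class PageStub {
 public:
  explicit PageStub(HistogramCutsStub cuts) : cuts_{std::move(cuts)} {}
  HistogramCutsStub const& Cuts() const { return cuts_; }  // read-only accessor

  bool is_dense {false};       // members that stayed public keep their plain names
  std::size_t row_stride {0};

 private:
  HistogramCutsStub cuts_;     // previously reached directly as page->cuts_
};

int main() {
  PageStub page{HistogramCutsStub{{0.5f, 1.5f}}};
  return page.Cuts().values.size() == 2 ? 0 : 1;
}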

View File

@@ -64,54 +64,55 @@ void RowPartitioner::SortPosition(common::Span<bst_node_t> position,
cub::DeviceScan::ExclusiveSum(temp_storage.data().get(), temp_storage_bytes,
in_itr, out_itr, position.size(), stream);
}
RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
: device_idx(device_idx) {
dh::safe_cuda(cudaSetDevice(device_idx));
ridx_a.resize(num_rows);
ridx_b.resize(num_rows);
position_a.resize(num_rows);
position_b.resize(num_rows);
ridx = dh::DoubleBuffer<RowIndexT>{&ridx_a, &ridx_b};
position = dh::DoubleBuffer<bst_node_t>{&position_a, &position_b};
ridx_segments.emplace_back(Segment(0, num_rows));
: device_idx_(device_idx) {
dh::safe_cuda(cudaSetDevice(device_idx_));
ridx_a_.resize(num_rows);
ridx_b_.resize(num_rows);
position_a_.resize(num_rows);
position_b_.resize(num_rows);
ridx_ = dh::DoubleBuffer<RowIndexT>{&ridx_a_, &ridx_b_};
position_ = dh::DoubleBuffer<bst_node_t>{&position_a_, &position_b_};
ridx_segments_.emplace_back(Segment(0, num_rows));
thrust::sequence(
thrust::device_pointer_cast(ridx.CurrentSpan().data()),
thrust::device_pointer_cast(ridx.CurrentSpan().data() + ridx.Size()));
thrust::device_pointer_cast(ridx_.CurrentSpan().data()),
thrust::device_pointer_cast(ridx_.CurrentSpan().data() + ridx_.Size()));
thrust::fill(
thrust::device_pointer_cast(position.Current()),
thrust::device_pointer_cast(position.Current() + position.Size()), 0);
left_counts.resize(256);
thrust::fill(left_counts.begin(), left_counts.end(), 0);
streams.resize(2);
for (auto& stream : streams) {
thrust::device_pointer_cast(position_.Current()),
thrust::device_pointer_cast(position_.Current() + position_.Size()), 0);
left_counts_.resize(256);
thrust::fill(left_counts_.begin(), left_counts_.end(), 0);
streams_.resize(2);
for (auto& stream : streams_) {
dh::safe_cuda(cudaStreamCreate(&stream));
}
}
RowPartitioner::~RowPartitioner() {
dh::safe_cuda(cudaSetDevice(device_idx));
for (auto& stream : streams) {
dh::safe_cuda(cudaSetDevice(device_idx_));
for (auto& stream : streams_) {
dh::safe_cuda(cudaStreamDestroy(stream));
}
}
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows(
bst_node_t nidx) {
auto segment = ridx_segments.at(nidx);
auto segment = ridx_segments_.at(nidx);
// Return empty span here as a valid result
// Will error if we try to construct a span from a pointer with size 0
if (segment.Size() == 0) {
return common::Span<const RowPartitioner::RowIndexT>();
}
return ridx.CurrentSpan().subspan(segment.begin, segment.Size());
return ridx_.CurrentSpan().subspan(segment.begin, segment.Size());
}
common::Span<const RowPartitioner::RowIndexT> RowPartitioner::GetRows() {
return ridx.CurrentSpan();
return ridx_.CurrentSpan();
}
common::Span<const bst_node_t> RowPartitioner::GetPosition() {
return position.CurrentSpan();
return position_.CurrentSpan();
}
std::vector<RowPartitioner::RowIndexT> RowPartitioner::GetRowsHost(
bst_node_t nidx) {
@@ -135,22 +136,22 @@ void RowPartitioner::SortPositionAndCopy(const Segment& segment,
cudaStream_t stream) {
SortPosition(
// position_in
common::Span<bst_node_t>(position.Current() + segment.begin,
common::Span<bst_node_t>(position_.Current() + segment.begin,
segment.Size()),
// position_out
common::Span<bst_node_t>(position.other() + segment.begin,
segment.Size()),
common::Span<bst_node_t>(position_.Other() + segment.begin,
segment.Size()),
// row index in
common::Span<RowIndexT>(ridx.Current() + segment.begin, segment.Size()),
common::Span<RowIndexT>(ridx_.Current() + segment.begin, segment.Size()),
// row index out
common::Span<RowIndexT>(ridx.other() + segment.begin, segment.Size()),
common::Span<RowIndexT>(ridx_.Other() + segment.begin, segment.Size()),
left_nidx, right_nidx, d_left_count, stream);
// Copy back key/value
const auto d_position_current = position.Current() + segment.begin;
const auto d_position_other = position.other() + segment.begin;
const auto d_ridx_current = ridx.Current() + segment.begin;
const auto d_ridx_other = ridx.other() + segment.begin;
dh::LaunchN(device_idx, segment.Size(), stream, [=] __device__(size_t idx) {
const auto d_position_current = position_.Current() + segment.begin;
const auto d_position_other = position_.Other() + segment.begin;
const auto d_ridx_current = ridx_.Current() + segment.begin;
const auto d_ridx_other = ridx_.Other() + segment.begin;
dh::LaunchN(device_idx_, segment.Size(), stream, [=] __device__(size_t idx) {
d_position_current[idx] = d_position_other[idx];
d_ridx_current[idx] = d_ridx_other[idx];
});
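
The bulk of this file is the rename ridx -> ridx_, position -> position_, and other() -> Other() on dh::DoubleBuffer. A toy host-side stand-in for that Current()/Other() interface (names and layout are illustrative; the real dh::DoubleBuffer is not part of this diff):

// Toy analogue of the double-buffer pattern used by RowPartitioner above.
#include <cstddef>
#include <vector>

template <typename T>
class ToyDoubleBuffer {
 public:
  ToyDoubleBuffer(std::vector<T>* a, std::vector<T>* b) : a_{a}, b_{b} {}
  T* Current() { return a_->data(); }       // buffer the rest of the code reads
  T* Other() { return b_->data(); }         // scratch buffer a pass writes into
  std::size_t Size() const { return a_->size(); }

 private:
  std::vector<T>* a_;
  std::vector<T>* b_;
};

int main() {
  std::vector<int> ridx_a(4), ridx_b(4);
  ToyDoubleBuffer<int> ridx{&ridx_a, &ridx_b};

  // A sorting pass scatters reordered row indices into Other() ...
  for (std::size_t i = 0; i < ridx.Size(); ++i) {
    ridx.Other()[i] = static_cast<int>(ridx.Size() - 1 - i);
  }
  // ... and the result is copied back into Current(), mirroring the LaunchN
  // kernel at the end of SortPositionAndCopy above.
  for (std::size_t i = 0; i < ridx.Size(); ++i) {
    ridx.Current()[i] = ridx.Other()[i];
  }
  return ridx.Current()[0] == 3 ? 0 : 1;
}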

View File

@@ -36,7 +36,7 @@ class RowPartitioner {
static constexpr bst_node_t kIgnoredTreePosition = -1;
private:
int device_idx;
int device_idx_;
/*! \brief In here if you want to find the rows belong to a node nid, first you need to
* get the indices segment from ridx_segments[nid], then get the row index that
* represents position of row in input data X. `RowPartitioner::GetRows` would be a
@@ -45,22 +45,22 @@ class RowPartitioner {
* node id -> segment -> indices of rows belonging to node
*/
/*! \brief Range of row index for each node, pointers into ridx below. */
std::vector<Segment> ridx_segments;
dh::caching_device_vector<RowIndexT> ridx_a;
dh::caching_device_vector<RowIndexT> ridx_b;
dh::caching_device_vector<bst_node_t> position_a;
dh::caching_device_vector<bst_node_t> position_b;
std::vector<Segment> ridx_segments_;
dh::caching_device_vector<RowIndexT> ridx_a_;
dh::caching_device_vector<RowIndexT> ridx_b_;
dh::caching_device_vector<bst_node_t> position_a_;
dh::caching_device_vector<bst_node_t> position_b_;
/*! \brief mapping for node id -> rows.
* This looks like:
* node id | 1 | 2 |
* rows idx | 3, 5, 1 | 13, 31 |
*/
dh::DoubleBuffer<RowIndexT> ridx;
dh::DoubleBuffer<RowIndexT> ridx_;
/*! \brief mapping for row -> node id. */
dh::DoubleBuffer<bst_node_t> position;
dh::DoubleBuffer<bst_node_t> position_;
dh::caching_device_vector<int64_t>
left_counts; // Useful to keep a bunch of zeroed memory for sort position
std::vector<cudaStream_t> streams;
left_counts_; // Useful to keep a bunch of zeroed memory for sort position
std::vector<cudaStream_t> streams_;
public:
RowPartitioner(int device_idx, size_t num_rows);
@@ -108,19 +108,19 @@ class RowPartitioner {
template <typename UpdatePositionOpT>
void UpdatePosition(bst_node_t nidx, bst_node_t left_nidx,
bst_node_t right_nidx, UpdatePositionOpT op) {
dh::safe_cuda(cudaSetDevice(device_idx));
Segment segment = ridx_segments.at(nidx); // rows belongs to node nidx
auto d_ridx = ridx.CurrentSpan();
auto d_position = position.CurrentSpan();
if (left_counts.size() <= nidx) {
left_counts.resize((nidx * 2) + 1);
thrust::fill(left_counts.begin(), left_counts.end(), 0);
dh::safe_cuda(cudaSetDevice(device_idx_));
Segment segment = ridx_segments_.at(nidx); // rows belongs to node nidx
auto d_ridx = ridx_.CurrentSpan();
auto d_position = position_.CurrentSpan();
if (left_counts_.size() <= nidx) {
left_counts_.resize((nidx * 2) + 1);
thrust::fill(left_counts_.begin(), left_counts_.end(), 0);
}
// Now we divide the row segment into left and right node.
int64_t* d_left_count = left_counts.data().get() + nidx;
int64_t* d_left_count = left_counts_.data().get() + nidx;
// Launch 1 thread for each row
dh::LaunchN<1, 128>(device_idx, segment.Size(), [=] __device__(size_t idx) {
dh::LaunchN<1, 128>(device_idx_, segment.Size(), [=] __device__(size_t idx) {
// LaunchN starts from zero, so we restore the row index by adding segment.begin
idx += segment.begin;
RowIndexT ridx = d_ridx[idx];
@@ -132,19 +132,19 @@ class RowPartitioner {
// Overlap device to host memory copy (left_count) with sort
int64_t left_count;
dh::safe_cuda(cudaMemcpyAsync(&left_count, d_left_count, sizeof(int64_t),
cudaMemcpyDeviceToHost, streams[0]));
cudaMemcpyDeviceToHost, streams_[0]));
SortPositionAndCopy(segment, left_nidx, right_nidx, d_left_count,
streams[1]);
streams_[1]);
dh::safe_cuda(cudaStreamSynchronize(streams[0]));
dh::safe_cuda(cudaStreamSynchronize(streams_[0]));
CHECK_LE(left_count, segment.Size());
CHECK_GE(left_count, 0);
ridx_segments.resize(std::max(int(ridx_segments.size()),
std::max(left_nidx, right_nidx) + 1));
ridx_segments[left_nidx] =
ridx_segments_.resize(std::max(static_cast<bst_node_t>(ridx_segments_.size()),
std::max(left_nidx, right_nidx) + 1));
ridx_segments_[left_nidx] =
Segment(segment.begin, segment.begin + left_count);
ridx_segments[right_nidx] =
ridx_segments_[right_nidx] =
Segment(segment.begin + left_count, segment.end);
}
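
UpdatePosition above routes each row of a node to its left or right child, counts the left rows, and then SortPositionAndCopy packs left-child rows ahead of right-child rows so the two children own contiguous sub-segments of ridx_segments_. A small host-side sketch of that idea (illustrative C++, not the CUDA implementation):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<std::uint32_t> ridx{3, 5, 1, 13, 31};   // row indices owned by one node
  std::vector<std::int32_t> position{2, 1, 2, 1, 2};  // 1 = left child, 2 = right child

  // A stable partition keeps relative order, mirroring (in spirit) the
  // scan-and-scatter that SortPosition performs on the device.
  std::vector<std::size_t> order(ridx.size());
  for (std::size_t i = 0; i < order.size(); ++i) order[i] = i;
  std::stable_partition(order.begin(), order.end(),
                        [&](std::size_t i) { return position[i] == 1; });

  auto left_count = static_cast<std::size_t>(
      std::count(position.begin(), position.end(), 1));

  // [0, left_count) is now the left child's segment, [left_count, size) the right's.
  std::cout << "left rows:";
  for (std::size_t i = 0; i < left_count; ++i) std::cout << ' ' << ridx[order[i]];
  std::cout << "\nright rows:";
  for (std::size_t i = left_count; i < order.size(); ++i) std::cout << ' ' << ridx[order[i]];
  std::cout << '\n';
  return 0;
}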
@@ -159,9 +159,9 @@ class RowPartitioner {
*/
template <typename FinalisePositionOpT>
void FinalisePosition(FinalisePositionOpT op) {
auto d_position = position.Current();
const auto d_ridx = ridx.Current();
dh::LaunchN(device_idx, position.Size(), [=] __device__(size_t idx) {
auto d_position = position_.Current();
const auto d_ridx = ridx_.Current();
dh::LaunchN(device_idx_, position_.Size(), [=] __device__(size_t idx) {
auto position = d_position[idx];
RowIndexT ridx = d_ridx[idx];
bst_node_t new_position = op(ridx, position);
@@ -189,10 +189,10 @@ class RowPartitioner {
/** \brief Used to demarcate a contiguous set of row indices associated with
* some tree node. */
struct Segment {
size_t begin;
size_t end;
size_t begin { 0 };
size_t end { 0 };
Segment() : begin{0}, end{0} {}
Segment() = default;
Segment(size_t begin, size_t end) : begin(begin), end(end) {
CHECK_GE(end, begin);
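
The final hunk replaces the hand-written default constructor with in-class member initializers plus = default, the usual fix suggested by clang-tidy's modernize-use-default-member-init and modernize-use-equals-default checks (assuming those are the checks that fired here). A standalone sketch of the pattern, with assert standing in for CHECK_GE:

#include <cassert>
#include <cstddef>

struct Segment {
  std::size_t begin {0};  // in-class default member initializers ...
  std::size_t end {0};

  Segment() = default;    // ... replace Segment() : begin{0}, end{0} {}
  Segment(std::size_t b, std::size_t e) : begin(b), end(e) { assert(e >= b); }
  std::size_t Size() const { return end - begin; }
};

int main() {
  Segment empty;            // begin == end == 0 via the member initializers
  Segment node_rows{2, 5};
  return (empty.Size() == 0 && node_rows.Size() == 3) ? 0 : 1;
}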