From fe8d72b50b132af6e24dfa6eb2e08d18430247e8 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 31 Jan 2020 14:52:15 +0800 Subject: [PATCH] Cleanup warnings. (#5247) From clang-tidy-9 and gcc-7: Invalid case style, narrowing definition, wrong initialization order, unused variables. --- include/xgboost/model.h | 2 + src/c_api/c_api.cc | 4 +- src/data/adapter.h | 338 ++++++++++----------- tests/cpp/common/test_hist_util.cc | 4 - tests/cpp/common/test_threading_utils.cc | 165 +++++----- tests/cpp/data/test_adapter.cc | 1 - tests/cpp/data/test_simple_dmatrix.cc | 4 +- tests/cpp/data/test_sparse_page_dmatrix.cc | 4 +- 8 files changed, 260 insertions(+), 262 deletions(-) diff --git a/include/xgboost/model.h b/include/xgboost/model.h index b1f024973..3b661ae81 100644 --- a/include/xgboost/model.h +++ b/include/xgboost/model.h @@ -15,6 +15,7 @@ namespace xgboost { class Json; struct Model { + virtual ~Model() = default; /*! * \brief load the model from a json object * \param in json object where to load the model from @@ -28,6 +29,7 @@ struct Model { }; struct Configurable { + virtual ~Configurable() = default; /*! * \brief Load configuration from JSON object * \param in JSON object containing the configuration diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 85a49f3c6..37cbd6dd2 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -257,7 +257,7 @@ XGB_DLL int XGDMatrixCreateFromMat(const bst_float* data, xgboost::bst_ulong ncol, bst_float missing, DMatrixHandle* out) { API_BEGIN(); - data::DenseAdapter adapter(data, nrow, nrow * ncol, ncol); + data::DenseAdapter adapter(data, nrow, ncol); *out = new std::shared_ptr(DMatrix::Create(&adapter, missing, 1)); API_END(); } @@ -268,7 +268,7 @@ XGB_DLL int XGDMatrixCreateFromMat_omp(const bst_float* data, // NOLINT bst_float missing, DMatrixHandle* out, int nthread) { API_BEGIN(); - data::DenseAdapter adapter(data, nrow, nrow * ncol, ncol); + data::DenseAdapter adapter(data, nrow, ncol); *out = new std::shared_ptr(DMatrix::Create(&adapter, missing, nthread)); API_END(); } diff --git a/src/data/adapter.h b/src/data/adapter.h index 896a44b94..57a0d7790 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -11,6 +11,9 @@ #include #include +#include "xgboost/base.h" +#include "xgboost/data.h" + namespace xgboost { namespace data { @@ -76,17 +79,17 @@ namespace detail { template class SingleBatchDataIter : dmlc::DataIter { public: - void BeforeFirst() override { counter = 0; } + void BeforeFirst() override { counter_ = 0; } bool Next() override { - if (counter == 0) { - counter++; + if (counter_ == 0) { + counter_++; return true; } return false; } private: - int counter{0}; + int counter_{0}; }; /** \brief Indicates this data source cannot contain meta-info such as labels, @@ -107,42 +110,42 @@ class CSRAdapterBatch : public detail::NoMetaInfo { public: Line(size_t row_idx, size_t size, const unsigned* feature_idx, const float* values) - : row_idx(row_idx), - size(size), - feature_idx(feature_idx), - values(values) {} + : row_idx_(row_idx), + size_(size), + feature_idx_(feature_idx), + values_(values) {} - size_t Size() const { return size; } + size_t Size() const { return size_; } COOTuple GetElement(size_t idx) const { - return COOTuple(row_idx, feature_idx[idx], values[idx]); + return COOTuple{row_idx_, feature_idx_[idx], values_[idx]}; } private: - size_t row_idx; - size_t size; - const unsigned* feature_idx; - const float* values; + size_t row_idx_; + size_t size_; + const unsigned* feature_idx_; + const float* values_; }; CSRAdapterBatch(const size_t* row_ptr, const unsigned* feature_idx, const float* values, size_t num_rows, size_t num_elements, size_t num_features) - : row_ptr(row_ptr), - feature_idx(feature_idx), - values(values), - num_rows(num_rows) {} + : row_ptr_(row_ptr), + feature_idx_(feature_idx), + values_(values), + num_rows_(num_rows) {} const Line GetLine(size_t idx) const { - size_t begin_offset = row_ptr[idx]; - size_t end_offset = row_ptr[idx + 1]; - return Line(idx, end_offset - begin_offset, &feature_idx[begin_offset], - &values[begin_offset]); + size_t begin_offset = row_ptr_[idx]; + size_t end_offset = row_ptr_[idx + 1]; + return Line(idx, end_offset - begin_offset, &feature_idx_[begin_offset], + &values_[begin_offset]); } - size_t Size() const { return num_rows; } + size_t Size() const { return num_rows_; } private: - const size_t* row_ptr; - const unsigned* feature_idx; - const float* values; - size_t num_rows; + const size_t* row_ptr_; + const unsigned* feature_idx_; + const float* values_; + size_t num_rows_; }; class CSRAdapter : public detail::SingleBatchDataIter { @@ -150,150 +153,146 @@ class CSRAdapter : public detail::SingleBatchDataIter { CSRAdapter(const size_t* row_ptr, const unsigned* feature_idx, const float* values, size_t num_rows, size_t num_elements, size_t num_features) - : batch(row_ptr, feature_idx, values, num_rows, num_elements, - num_features), - num_rows(num_rows), - num_columns(num_features) {} - const CSRAdapterBatch& Value() const override { return batch; } - size_t NumRows() const { return num_rows; } - size_t NumColumns() const { return num_columns; } + : batch_(row_ptr, feature_idx, values, num_rows, num_elements, + num_features), + num_rows_(num_rows), + num_columns_(num_features) {} + const CSRAdapterBatch& Value() const override { return batch_; } + size_t NumRows() const { return num_rows_; } + size_t NumColumns() const { return num_columns_; } private: - CSRAdapterBatch batch; - size_t num_rows; - size_t num_columns; + CSRAdapterBatch batch_; + size_t num_rows_; + size_t num_columns_; }; class DenseAdapterBatch : public detail::NoMetaInfo { public: - DenseAdapterBatch(const float* values, size_t num_rows, size_t num_elements, - size_t num_features) - : num_features(num_features), - num_rows(num_rows), - num_elements(num_elements), - values(values) {} + DenseAdapterBatch(const float* values, size_t num_rows, size_t num_features) + : values_(values), + num_rows_(num_rows), + num_features_(num_features) {} private: class Line { public: Line(const float* values, size_t size, size_t row_idx) - : row_idx(row_idx), size(size), values(values) {} + : row_idx_(row_idx), size_(size), values_(values) {} - size_t Size() const { return size; } + size_t Size() const { return size_; } COOTuple GetElement(size_t idx) const { - return COOTuple(row_idx, idx, values[idx]); + return COOTuple{row_idx_, idx, values_[idx]}; } private: - size_t row_idx; - size_t size; - const float* values; + size_t row_idx_; + size_t size_; + const float* values_; }; public: - size_t Size() const { return num_rows; } + size_t Size() const { return num_rows_; } const Line GetLine(size_t idx) const { - return Line(values + idx * num_features, num_features, idx); + return Line(values_ + idx * num_features_, num_features_, idx); } private: - const float* values; - size_t num_elements; - size_t num_rows; - size_t num_features; + const float* values_; + size_t num_rows_; + size_t num_features_; }; class DenseAdapter : public detail::SingleBatchDataIter { public: - DenseAdapter(const float* values, size_t num_rows, size_t num_elements, - size_t num_features) - : batch(values, num_rows, num_elements, num_features), - num_rows(num_rows), - num_columns(num_features) {} - const DenseAdapterBatch& Value() const override { return batch; } + DenseAdapter(const float* values, size_t num_rows, size_t num_features) + : batch_(values, num_rows, num_features), + num_rows_(num_rows), + num_columns_(num_features) {} + const DenseAdapterBatch& Value() const override { return batch_; } - size_t NumRows() const { return num_rows; } - size_t NumColumns() const { return num_columns; } + size_t NumRows() const { return num_rows_; } + size_t NumColumns() const { return num_columns_; } private: - DenseAdapterBatch batch; - size_t num_rows; - size_t num_columns; + DenseAdapterBatch batch_; + size_t num_rows_; + size_t num_columns_; }; class CSCAdapterBatch : public detail::NoMetaInfo { public: CSCAdapterBatch(const size_t* col_ptr, const unsigned* row_idx, const float* values, size_t num_features) - : col_ptr(col_ptr), - row_idx(row_idx), - values(values), - num_features(num_features) {} + : col_ptr_(col_ptr), + row_idx_(row_idx), + values_(values), + num_features_(num_features) {} private: class Line { public: Line(size_t col_idx, size_t size, const unsigned* row_idx, const float* values) - : col_idx(col_idx), size(size), row_idx(row_idx), values(values) {} + : col_idx_(col_idx), size_(size), row_idx_(row_idx), values_(values) {} - size_t Size() const { return size; } + size_t Size() const { return size_; } COOTuple GetElement(size_t idx) const { - return COOTuple(row_idx[idx], col_idx, values[idx]); + return COOTuple{row_idx_[idx], col_idx_, values_[idx]}; } private: - size_t col_idx; - size_t size; - const unsigned* row_idx; - const float* values; + size_t col_idx_; + size_t size_; + const unsigned* row_idx_; + const float* values_; }; public: - size_t Size() const { return num_features; } + size_t Size() const { return num_features_; } const Line GetLine(size_t idx) const { - size_t begin_offset = col_ptr[idx]; - size_t end_offset = col_ptr[idx + 1]; - return Line(idx, end_offset - begin_offset, &row_idx[begin_offset], - &values[begin_offset]); + size_t begin_offset = col_ptr_[idx]; + size_t end_offset = col_ptr_[idx + 1]; + return Line(idx, end_offset - begin_offset, &row_idx_[begin_offset], + &values_[begin_offset]); } private: - const size_t* col_ptr; - const unsigned* row_idx; - const float* values; - size_t num_features; + const size_t* col_ptr_; + const unsigned* row_idx_; + const float* values_; + size_t num_features_; }; class CSCAdapter : public detail::SingleBatchDataIter { public: CSCAdapter(const size_t* col_ptr, const unsigned* row_idx, const float* values, size_t num_features, size_t num_rows) - : batch(col_ptr, row_idx, values, num_features), - num_rows(num_rows), - num_columns(num_features) {} - const CSCAdapterBatch& Value() const override { return batch; } + : batch_(col_ptr, row_idx, values, num_features), + num_rows_(num_rows), + num_columns_(num_features) {} + const CSCAdapterBatch& Value() const override { return batch_; } // JVM package sends 0 as unknown size_t NumRows() const { - return num_rows == 0 ? kAdapterUnknownSize : num_rows; + return num_rows_ == 0 ? kAdapterUnknownSize : num_rows_; } - size_t NumColumns() const { return num_columns; } + size_t NumColumns() const { return num_columns_; } private: - CSCAdapterBatch batch; - size_t num_rows; - size_t num_columns; + CSCAdapterBatch batch_; + size_t num_rows_; + size_t num_columns_; }; class DataTableAdapterBatch : public detail::NoMetaInfo { public: DataTableAdapterBatch(void** data, const char** feature_stypes, size_t num_rows, size_t num_features) - : data(data), - feature_stypes(feature_stypes), - num_features(num_features), - num_rows(num_rows) {} + : data_(data), + feature_stypes_(feature_stypes), + num_features_(num_features), + num_rows_(num_rows) {} private: enum class DTType : uint8_t { @@ -370,31 +369,31 @@ class DataTableAdapterBatch : public detail::NoMetaInfo { public: Line(DTType type, size_t size, size_t column_idx, const void* column) - : type(type), size(size), column_idx(column_idx), column(column) {} + : type_(type), size_(size), column_idx_(column_idx), column_(column) {} - size_t Size() const { return size; } + size_t Size() const { return size_; } COOTuple GetElement(size_t idx) const { - return COOTuple(idx, column_idx, DTGetValue(column, type, idx)); + return COOTuple{idx, column_idx_, DTGetValue(column_, type_, idx)}; } private: - DTType type; - size_t size; - size_t column_idx; - const void* column; + DTType type_; + size_t size_; + size_t column_idx_; + const void* column_; }; public: - size_t Size() const { return num_features; } + size_t Size() const { return num_features_; } const Line GetLine(size_t idx) const { - return Line(DTGetType(feature_stypes[idx]), num_rows, idx, data[idx]); + return Line(DTGetType(feature_stypes_[idx]), num_rows_, idx, data_[idx]); } private: - void** data; - const char** feature_stypes; - size_t num_features; - size_t num_rows; + void** data_; + const char** feature_stypes_; + size_t num_features_; + size_t num_rows_; }; class DataTableAdapter @@ -402,17 +401,17 @@ class DataTableAdapter public: DataTableAdapter(void** data, const char** feature_stypes, size_t num_rows, size_t num_features) - : batch(data, feature_stypes, num_rows, num_features), - num_rows(num_rows), - num_columns(num_features) {} - const DataTableAdapterBatch& Value() const override { return batch; } - size_t NumRows() const { return num_rows; } - size_t NumColumns() const { return num_columns; } + : batch_(data, feature_stypes, num_rows, num_features), + num_rows_(num_rows), + num_columns_(num_features) {} + const DataTableAdapterBatch& Value() const override { return batch_; } + size_t NumRows() const { return num_rows_; } + size_t NumColumns() const { return num_columns_; } private: - DataTableAdapterBatch batch; - size_t num_rows; - size_t num_columns; + DataTableAdapterBatch batch_; + size_t num_rows_; + size_t num_columns_; }; class FileAdapterBatch { @@ -421,59 +420,59 @@ class FileAdapterBatch { public: Line(size_t row_idx, const uint32_t* feature_idx, const float* value, size_t size) - : row_idx(row_idx), - feature_idx(feature_idx), - value(value), - size(size) {} + : row_idx_(row_idx), + feature_idx_(feature_idx), + value_(value), + size_(size) {} - size_t Size() { return size; } + size_t Size() { return size_; } COOTuple GetElement(size_t idx) { - float fvalue = value == nullptr ? 1.0f : value[idx]; - return COOTuple(row_idx, feature_idx[idx], fvalue); + float fvalue = value_ == nullptr ? 1.0f : value_[idx]; + return COOTuple{row_idx_, feature_idx_[idx], fvalue}; } private: - size_t row_idx; - const uint32_t* feature_idx; - const float* value; - size_t size; + size_t row_idx_; + const uint32_t* feature_idx_; + const float* value_; + size_t size_; }; FileAdapterBatch(const dmlc::RowBlock* block, size_t row_offset) - : block(block), row_offset(row_offset) {} + : block_(block), row_offset_(row_offset) {} Line GetLine(size_t idx) const { - auto begin = block->offset[idx]; - auto end = block->offset[idx + 1]; - return Line(idx + row_offset, &block->index[begin], &block->value[begin], - end - begin); + auto begin = block_->offset[idx]; + auto end = block_->offset[idx + 1]; + return Line{idx + row_offset_, &block_->index[begin], &block_->value[begin], + end - begin}; } - const float* Labels() const { return block->label; } - const float* Weights() const { return block->weight; } - const uint64_t* Qid() const { return block->qid; } + const float* Labels() const { return block_->label; } + const float* Weights() const { return block_->weight; } + const uint64_t* Qid() const { return block_->qid; } const float* BaseMargin() const { return nullptr; } - size_t Size() const { return block->size; } + size_t Size() const { return block_->size; } private: - const dmlc::RowBlock* block; - size_t row_offset; + const dmlc::RowBlock* block_; + size_t row_offset_; }; /** \brief FileAdapter wraps dmlc::parser to read files and provide access in a * common interface. */ class FileAdapter : dmlc::DataIter { public: - explicit FileAdapter(dmlc::Parser* parser) : parser(parser) {} + explicit FileAdapter(dmlc::Parser* parser) : parser_(parser) {} - const FileAdapterBatch& Value() const override { return *batch.get(); } + const FileAdapterBatch& Value() const override { return *batch_.get(); } void BeforeFirst() override { - batch.reset(); - parser->BeforeFirst(); - row_offset = 0; + batch_.reset(); + parser_->BeforeFirst(); + row_offset_ = 0; } bool Next() override { - bool next = parser->Next(); - batch.reset(new FileAdapterBatch(&parser->Value(), row_offset)); - row_offset += parser->Value().size; + bool next = parser_->Next(); + batch_.reset(new FileAdapterBatch(&parser_->Value(), row_offset_)); + row_offset_ += parser_->Value().size; return next; } // Indicates a number of rows/columns must be inferred @@ -481,9 +480,9 @@ class FileAdapter : dmlc::DataIter { size_t NumColumns() const { return kAdapterUnknownSize; } private: - size_t row_offset{0}; - std::unique_ptr batch; - dmlc::Parser* parser; + size_t row_offset_{0}; + std::unique_ptr batch_; + dmlc::Parser* parser_; }; class DMatrixSliceAdapterBatch { @@ -510,16 +509,16 @@ class DMatrixSliceAdapterBatch { class Line { public: Line(const SparsePage::Inst& inst, size_t row_idx) - : inst(inst), row_idx(row_idx) {} + : inst_(inst), row_idx_(row_idx) {} - size_t Size() { return inst.size(); } + size_t Size() { return inst_.size(); } COOTuple GetElement(size_t idx) { - return COOTuple(row_idx, inst[idx].index, inst[idx].fvalue); + return COOTuple{row_idx_, inst_[idx].index, inst_[idx].fvalue}; } private: - SparsePage::Inst inst; - size_t row_idx; + SparsePage::Inst inst_; + size_t row_idx_; }; Line GetLine(size_t idx) const { return Line(batch[ridx_set[idx]], idx); } const float* Labels() const { @@ -559,17 +558,18 @@ class DMatrixSliceAdapter : public detail::SingleBatchDataIter { public: DMatrixSliceAdapter(DMatrix* dmat, common::Span ridx_set) - : dmat(dmat), - ridx_set(ridx_set), - batch(*dmat->GetBatches().begin(), dmat, ridx_set) {} - const DMatrixSliceAdapterBatch& Value() const override { return batch; } + : dmat_(dmat), + ridx_set_(ridx_set), + batch_(*dmat_->GetBatches().begin(), dmat_, ridx_set) {} + const DMatrixSliceAdapterBatch& Value() const override { return batch_; } // Indicates a number of rows/columns must be inferred - size_t NumRows() const { return ridx_set.size(); } - size_t NumColumns() const { return dmat->Info().num_col_; } - DMatrix* dmat; - DMatrixSliceAdapterBatch batch; - bool before_first{true}; - common::Span ridx_set; + size_t NumRows() const { return ridx_set_.size(); } + size_t NumColumns() const { return dmat_->Info().num_col_; } + + private: + DMatrix* dmat_; + common::Span ridx_set_; + DMatrixSliceAdapterBatch batch_; }; }; // namespace data } // namespace xgboost diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index 4eb7cb68a..08c721136 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -46,7 +46,6 @@ TEST(ParallelGHistBuilder, Reset) { hist_builder.Reset(nthreads, kNodes, space, target_hist); common::ParallelFor2d(space, nthreads, [&](size_t inode, common::Range1d r) { - const size_t itask = r.begin(); const size_t tid = omp_get_thread_num(); GHistRow hist = hist_builder.GetInitializedHist(tid, inode); @@ -65,7 +64,6 @@ TEST(ParallelGHistBuilder, Reset) { hist_builder.Reset(nthreads, kNodesExtended, space2, target_hist); common::ParallelFor2d(space2, nthreads, [&](size_t inode, common::Range1d r) { - const size_t itask = r.begin(); const size_t tid = omp_get_thread_num(); GHistRow hist = hist_builder.GetInitializedHist(tid, inode); @@ -80,7 +78,6 @@ TEST(ParallelGHistBuilder, Reset) { TEST(ParallelGHistBuilder, ReduceHist) { constexpr size_t kBins = 10; constexpr size_t kNodes = 5; - constexpr size_t kNodesExtended = 10; constexpr size_t kTasksPerNode = 10; constexpr double kValue = 1.0; const size_t nthreads = GetNThreads(); @@ -104,7 +101,6 @@ TEST(ParallelGHistBuilder, ReduceHist) { // Simple analog of BuildHist function, works in parallel for both tree-nodes and data in node common::ParallelFor2d(space, nthreads, [&](size_t inode, common::Range1d r) { - const size_t itask = r.begin(); const size_t tid = omp_get_thread_num(); GHistRow hist = hist_builder.GetInitializedHist(tid, inode); diff --git a/tests/cpp/common/test_threading_utils.cc b/tests/cpp/common/test_threading_utils.cc index 0d49c490a..392e4ba98 100755 --- a/tests/cpp/common/test_threading_utils.cc +++ b/tests/cpp/common/test_threading_utils.cc @@ -1,82 +1,83 @@ -#include - -#include "../../../src/common/column_matrix.h" -#include "../../../src/common/threading_utils.h" - -namespace xgboost { -namespace common { - -TEST(CreateBlockedSpace2d, Test) { - constexpr size_t kDim1 = 5; - constexpr size_t kDim2 = 3; - constexpr size_t kGrainSize = 1; - - BlockedSpace2d space(kDim1, [&](size_t i) { - return kDim2; - }, kGrainSize); - - ASSERT_EQ(kDim1 * kDim2, space.Size()); - - for (auto i = 0; i < kDim1; i++) { - for (auto j = 0; j < kDim2; j++) { - ASSERT_EQ(space.GetFirstDimension(i*kDim2 + j), i); - ASSERT_EQ(j, space.GetRange(i*kDim2 + j).begin()); - ASSERT_EQ(j + kGrainSize, space.GetRange(i*kDim2 + j).end()); - } - } -} - -TEST(ParallelFor2d, Test) { - constexpr size_t kDim1 = 100; - constexpr size_t kDim2 = 15; - constexpr size_t kGrainSize = 2; - - // working space is matrix of size (kDim1 x kDim2) - std::vector matrix(kDim1 * kDim2, 0); - BlockedSpace2d space(kDim1, [&](size_t i) { - return kDim2; - }, kGrainSize); - - ParallelFor2d(space, 4, [&](size_t i, Range1d r) { - for (auto j = r.begin(); j < r.end(); ++j) { - matrix[i*kDim2 + j] += 1; - } - }); - - for (auto i = 0; i < kDim1 * kDim2; i++) { - ASSERT_EQ(matrix[i], 1); - } -} - -TEST(ParallelFor2dNonUniform, Test) { - constexpr size_t kDim1 = 5; - constexpr size_t kGrainSize = 256; - - // here are quite non-uniform distribution in space - // but ParallelFor2d should split them by blocks with max size = kGrainSize - // and process in balanced manner (optimal performance) - std::vector dim2 { 1024, 500, 255, 5, 10000 }; - BlockedSpace2d space(kDim1, [&](size_t i) { - return dim2[i]; - }, kGrainSize); - - std::vector> working_space(kDim1); - for (auto i = 0; i < kDim1; i++) { - working_space[i].resize(dim2[i], 0); - } - - ParallelFor2d(space, 4, [&](size_t i, Range1d r) { - for (auto j = r.begin(); j < r.end(); ++j) { - working_space[i][j] += 1; - } - }); - - for (auto i = 0; i < kDim1; i++) { - for (auto j = 0; j < dim2[i]; j++) { - ASSERT_EQ(working_space[i][j], 1); - } - } -} - -} // namespace common -} // namespace xgboost +#include +#include + +#include "../../../src/common/column_matrix.h" +#include "../../../src/common/threading_utils.h" + +namespace xgboost { +namespace common { + +TEST(CreateBlockedSpace2d, Test) { + constexpr size_t kDim1 = 5; + constexpr size_t kDim2 = 3; + constexpr size_t kGrainSize = 1; + + BlockedSpace2d space(kDim1, [&](size_t i) { + return kDim2; + }, kGrainSize); + + ASSERT_EQ(kDim1 * kDim2, space.Size()); + + for (size_t i = 0; i < kDim1; i++) { + for (size_t j = 0; j < kDim2; j++) { + ASSERT_EQ(space.GetFirstDimension(i*kDim2 + j), i); + ASSERT_EQ(j, space.GetRange(i*kDim2 + j).begin()); + ASSERT_EQ(j + kGrainSize, space.GetRange(i*kDim2 + j).end()); + } + } +} + +TEST(ParallelFor2d, Test) { + constexpr size_t kDim1 = 100; + constexpr size_t kDim2 = 15; + constexpr size_t kGrainSize = 2; + + // working space is matrix of size (kDim1 x kDim2) + std::vector matrix(kDim1 * kDim2, 0); + BlockedSpace2d space(kDim1, [&](size_t i) { + return kDim2; + }, kGrainSize); + + ParallelFor2d(space, 4, [&](size_t i, Range1d r) { + for (auto j = r.begin(); j < r.end(); ++j) { + matrix[i*kDim2 + j] += 1; + } + }); + + for (size_t i = 0; i < kDim1 * kDim2; i++) { + ASSERT_EQ(matrix[i], 1); + } +} + +TEST(ParallelFor2dNonUniform, Test) { + constexpr size_t kDim1 = 5; + constexpr size_t kGrainSize = 256; + + // here are quite non-uniform distribution in space + // but ParallelFor2d should split them by blocks with max size = kGrainSize + // and process in balanced manner (optimal performance) + std::vector dim2 { 1024, 500, 255, 5, 10000 }; + BlockedSpace2d space(kDim1, [&](size_t i) { + return dim2[i]; + }, kGrainSize); + + std::vector> working_space(kDim1); + for (size_t i = 0; i < kDim1; i++) { + working_space[i].resize(dim2[i], 0); + } + + ParallelFor2d(space, 4, [&](size_t i, Range1d r) { + for (auto j = r.begin(); j < r.end(); ++j) { + working_space[i][j] += 1; + } + }); + + for (size_t i = 0; i < kDim1; i++) { + for (size_t j = 0; j < dim2[i]; j++) { + ASSERT_EQ(working_space[i][j], 1); + } + } +} + +} // namespace common +} // namespace xgboost diff --git a/tests/cpp/data/test_adapter.cc b/tests/cpp/data/test_adapter.cc index d785ab5d1..4c048ff97 100644 --- a/tests/cpp/data/test_adapter.cc +++ b/tests/cpp/data/test_adapter.cc @@ -7,7 +7,6 @@ #include "../helpers.h" using namespace xgboost; // NOLINT TEST(adapter, CSRAdapter) { - int m = 3; int n = 2; std::vector data = {1, 2, 3, 4, 5}; std::vector feature_idx = {0, 1, 0, 1, 1}; diff --git a/tests/cpp/data/test_simple_dmatrix.cc b/tests/cpp/data/test_simple_dmatrix.cc index 0524165cb..0a66af5a0 100644 --- a/tests/cpp/data/test_simple_dmatrix.cc +++ b/tests/cpp/data/test_simple_dmatrix.cc @@ -83,7 +83,7 @@ TEST(SimpleDMatrix, Empty) { CHECK_EQ(batch.Size(), 0); } - data::DenseAdapter dense_adapter(nullptr, 0, 0, 0); + data::DenseAdapter dense_adapter(nullptr, 0, 0); dmat = data::SimpleDMatrix(&dense_adapter, std::numeric_limits::quiet_NaN(), 1); CHECK_EQ(dmat.Info().num_nonzero_, 0); @@ -136,7 +136,7 @@ TEST(SimpleDMatrix, FromDense) { int m = 3; int n = 2; std::vector data = {1, 2, 3, 4, 5, 6}; - data::DenseAdapter adapter(data.data(), m, m * n, n); + data::DenseAdapter adapter(data.data(), m, n); data::SimpleDMatrix dmat(&adapter, std::numeric_limits::quiet_NaN(), -1); EXPECT_EQ(dmat.Info().num_col_, 2); diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cc b/tests/cpp/data/test_sparse_page_dmatrix.cc index 0be1a9e4b..c64b04ce5 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cc +++ b/tests/cpp/data/test_sparse_page_dmatrix.cc @@ -106,7 +106,7 @@ TEST(SparsePageDMatrix, Empty) { EXPECT_EQ(batch.Size(), 0); } - data::DenseAdapter dense_adapter(nullptr, 0, 0, 0); + data::DenseAdapter dense_adapter(nullptr, 0, 0); data::SparsePageDMatrix dmat2(&dense_adapter, std::numeric_limits::quiet_NaN(), 1,tmp_file); EXPECT_EQ(dmat2.Info().num_nonzero_, 0); @@ -163,7 +163,7 @@ TEST(SparsePageDMatrix, FromDense) { int m = 3; int n = 2; std::vector data = {1, 2, 3, 4, 5, 6}; - data::DenseAdapter adapter(data.data(), m, m * n, n); + data::DenseAdapter adapter(data.data(), m, n); data::SparsePageDMatrix dmat( &adapter, std::numeric_limits::quiet_NaN(), 1, tmp_file); EXPECT_EQ(dmat.Info().num_col_, 2);