Memory consumption fix for row-major adapters (#6779)

Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
Co-authored-by: fis <jm.yuan@outlook.com>
This commit is contained in:
ShvetsKS
2021-03-26 03:44:30 +03:00
committed by GitHub
parent 744c46995c
commit 8825670c9c
4 changed files with 78 additions and 31 deletions

View File

@@ -148,6 +148,7 @@ class CSRAdapterBatch : public detail::NoMetaInfo {
&values_[begin_offset]);
}
size_t Size() const { return num_rows_; }
static constexpr bool kIsRowMajor = true;
private:
const size_t* row_ptr_;
@@ -204,6 +205,7 @@ class DenseAdapterBatch : public detail::NoMetaInfo {
const Line GetLine(size_t idx) const {
return Line(values_ + idx * num_features_, num_features_, idx);
}
static constexpr bool kIsRowMajor = true;
private:
const float* values_;
@@ -320,6 +322,7 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
size = size == 0 ? 0 : size - 1;
return size;
}
static constexpr bool kIsRowMajor = true;
Line const GetLine(size_t idx) const {
auto begin_offset = indptr_.GetElement<size_t>(idx, 0);
@@ -405,6 +408,7 @@ class CSCAdapterBatch : public detail::NoMetaInfo {
return Line(idx, end_offset - begin_offset, &row_idx_[begin_offset],
&values_[begin_offset]);
}
static constexpr bool kIsRowMajor = false;
private:
const size_t* col_ptr_;
@@ -537,6 +541,7 @@ class DataTableAdapterBatch : public detail::NoMetaInfo {
const Line GetLine(size_t idx) const {
return Line(DTGetType(feature_stypes_[idx]), num_rows_, idx, data_[idx]);
}
static constexpr bool kIsRowMajor = false;
private:
void** data_;
@@ -600,6 +605,7 @@ class FileAdapterBatch {
const float* BaseMargin() const { return nullptr; }
size_t Size() const { return block_->size; }
static constexpr bool kIsRowMajor = true;
private:
const dmlc::RowBlock<uint32_t>* block_;

View File

@@ -872,14 +872,20 @@ void SparsePage::Push(const SparsePage &batch) {
template <typename AdapterBatchT>
uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread) {
constexpr bool kIsRowMajor = AdapterBatchT::kIsRowMajor;
// Allow threading only for row-major case as column-major requires O(nthread*batch_size) memory
nthread = kIsRowMajor ? nthread : 1;
// Set number of threads but keep old value so we can reset it after
int nthread_original = common::OmpSetNumThreadsWithoutHT(&nthread);
if (!kIsRowMajor) {
CHECK_EQ(nthread, 1);
}
auto& offset_vec = offset.HostVector();
auto& data_vec = data.HostVector();
size_t builder_base_row_offset = this->Size();
common::ParallelGroupBuilder<
Entry, std::remove_reference<decltype(offset_vec)>::type::value_type>
Entry, std::remove_reference<decltype(offset_vec)>::type::value_type, kIsRowMajor>
builder(&offset_vec, &data_vec, builder_base_row_offset);
// Estimate expected number of rows by using last element in batch
// This is not required to be exact but prevents unnecessary resizing
@@ -892,13 +898,15 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
}
}
size_t batch_size = batch.Size();
const size_t thread_size = batch_size / nthread;
builder.InitBudget(expected_rows+1, nthread);
expected_rows = kIsRowMajor ? batch_size : expected_rows;
uint64_t max_columns = 0;
if (batch_size == 0) {
omp_set_num_threads(nthread_original);
return max_columns;
}
const size_t thread_size = batch_size / nthread;
builder.InitBudget(expected_rows, nthread);
std::vector<std::vector<uint64_t>> max_columns_vector(nthread);
dmlc::OMPException exec;
std::atomic<bool> valid{true};

View File

@@ -91,9 +91,6 @@ BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param)
template <typename AdapterT>
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
// Set number of threads but keep old value so we can reset it after
int nthread_original = common::OmpSetNumThreadsWithoutHT(&nthread);
std::vector<uint64_t> qids;
uint64_t default_max = std::numeric_limits<uint64_t>::max();
uint64_t last_group_id = default_max;
@@ -184,7 +181,6 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
info_.num_row_ = adapter->NumRows();
}
info_.num_nonzero_ = data_vec.size();
omp_set_num_threads(nthread_original);
}
SimpleDMatrix::SimpleDMatrix(dmlc::Stream* in_stream) {