Memory consumption fix for row-major adapters (#6779)
Co-authored-by: Kirill Shvets <kirill.shvets@intel.com> Co-authored-by: fis <jm.yuan@outlook.com>
This commit is contained in:
@@ -872,14 +872,20 @@ void SparsePage::Push(const SparsePage &batch) {
|
||||
|
||||
template <typename AdapterBatchT>
|
||||
uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread) {
|
||||
constexpr bool kIsRowMajor = AdapterBatchT::kIsRowMajor;
|
||||
// Allow threading only for row-major case as column-major requires O(nthread*batch_size) memory
|
||||
nthread = kIsRowMajor ? nthread : 1;
|
||||
// Set number of threads but keep old value so we can reset it after
|
||||
int nthread_original = common::OmpSetNumThreadsWithoutHT(&nthread);
|
||||
if (!kIsRowMajor) {
|
||||
CHECK_EQ(nthread, 1);
|
||||
}
|
||||
auto& offset_vec = offset.HostVector();
|
||||
auto& data_vec = data.HostVector();
|
||||
|
||||
size_t builder_base_row_offset = this->Size();
|
||||
common::ParallelGroupBuilder<
|
||||
Entry, std::remove_reference<decltype(offset_vec)>::type::value_type>
|
||||
Entry, std::remove_reference<decltype(offset_vec)>::type::value_type, kIsRowMajor>
|
||||
builder(&offset_vec, &data_vec, builder_base_row_offset);
|
||||
// Estimate expected number of rows by using last element in batch
|
||||
// This is not required to be exact but prevents unnecessary resizing
|
||||
@@ -892,13 +898,15 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
|
||||
}
|
||||
}
|
||||
size_t batch_size = batch.Size();
|
||||
const size_t thread_size = batch_size / nthread;
|
||||
builder.InitBudget(expected_rows+1, nthread);
|
||||
expected_rows = kIsRowMajor ? batch_size : expected_rows;
|
||||
uint64_t max_columns = 0;
|
||||
if (batch_size == 0) {
|
||||
omp_set_num_threads(nthread_original);
|
||||
return max_columns;
|
||||
}
|
||||
const size_t thread_size = batch_size / nthread;
|
||||
|
||||
builder.InitBudget(expected_rows, nthread);
|
||||
std::vector<std::vector<uint64_t>> max_columns_vector(nthread);
|
||||
dmlc::OMPException exec;
|
||||
std::atomic<bool> valid{true};
|
||||
|
||||
Reference in New Issue
Block a user