From 2937f5eebcf7a7b6a873ab5576a5f5f66a9a71d4 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 2 Jun 2015 23:18:31 -0700 Subject: [PATCH] io part refactor --- src/data.h | 6 +- src/io/page_dmatrix-inl.hpp | 11 +- src/io/page_fmatrix-inl.hpp | 16 +- src/io/simple_dmatrix-inl.hpp | 43 +++++- src/io/simple_fmatrix-inl.hpp | 279 +++++++++++++++++++++++----------- src/io/sparse_batch_page.h | 18 ++- src/learner/learner-inl.hpp | 16 +- 7 files changed, 276 insertions(+), 113 deletions(-) diff --git a/src/data.h b/src/data.h index d1f5eb427..63dd2d78f 100644 --- a/src/data.h +++ b/src/data.h @@ -140,8 +140,12 @@ class IFMatrix { * \brief check if column access is supported, if not, initialize column access * \param enabled whether certain feature should be included in column access * \param subsample subsample ratio when generating column access + * \param max_row_perbatch auxiliary information, maximum row used in each column batch + * this is hint information that can be ignored by the implementation */ - virtual void InitColAccess(const std::vector &enabled, float subsample) = 0; + virtual void InitColAccess(const std::vector &enabled, + float subsample, + size_t max_row_perbatch) = 0; // the following are column meta data, should be able to answer them fast /*! \return whether column access is enabled */ virtual bool HaveColAccess(void) const = 0; diff --git a/src/io/page_dmatrix-inl.hpp b/src/io/page_dmatrix-inl.hpp index 8fb02e18e..79455d130 100644 --- a/src/io/page_dmatrix-inl.hpp +++ b/src/io/page_dmatrix-inl.hpp @@ -33,10 +33,7 @@ class ThreadRowPageIterator: public utils::IIterator { } virtual bool Next(void) { if (!itr.Next(page_)) return false; - out_.base_rowid = base_rowid_; - out_.ind_ptr = BeginPtr(page_->offset); - out_.data_ptr = BeginPtr(page_->data); - out_.size = page_->offset.size() - 1; + out_ = page_->GetRowBatch(base_rowid_); base_rowid_ += out_.size; return true; } @@ -198,8 +195,8 @@ class DMatrixPageBase : public DataMatrix { } /*! 
\brief magic number used to identify DMatrix */ static const int kMagic = TKMagic; - /*! \brief page size 64 MB */ - static const size_t kPageSize = 64UL << 20UL; + /*! \brief page size 32 MB */ + static const size_t kPageSize = 32UL << 20UL; protected: virtual void set_cache_file(const std::string &cache_file) = 0; @@ -236,7 +233,7 @@ class DMatrixPage : public DMatrixPageBase<0xffffab02> { class DMatrixHalfRAM : public DMatrixPageBase<0xffffab03> { public: DMatrixHalfRAM(void) { - fmat_ = new FMatrixS(iter_); + fmat_ = new FMatrixS(iter_, this->info); } virtual ~DMatrixHalfRAM(void) { delete fmat_; diff --git a/src/io/page_fmatrix-inl.hpp b/src/io/page_fmatrix-inl.hpp index 7d4cdb9cf..18f4c6dee 100644 --- a/src/io/page_fmatrix-inl.hpp +++ b/src/io/page_fmatrix-inl.hpp @@ -58,11 +58,13 @@ struct ColConvertFactory { return true; } inline void Setup(float pkeep, + size_t max_row_perbatch, size_t num_col, utils::IIterator *iter, std::vector *buffered_rowset, const std::vector *enabled) { pkeep_ = pkeep; + max_row_perbatch_ = max_row_perbatch; num_col_ = num_col; iter_ = iter; buffered_rowset_ = buffered_rowset; @@ -87,7 +89,8 @@ struct ColConvertFactory { tmp_.Push(batch[i]); } } - if (tmp_.MemCostBytes() >= kPageSize) { + if (tmp_.MemCostBytes() >= kPageSize || + tmp_.Size() >= max_row_perbatch_) { this->MakeColPage(tmp_, BeginPtr(*buffered_rowset_) + btop, *enabled_, val); return true; @@ -157,6 +160,8 @@ struct ColConvertFactory { } // probability of keep float pkeep_; + // maximum number of rows per batch + size_t max_row_perbatch_; // number of columns size_t num_col_; // row batch iterator @@ -208,10 +213,10 @@ class FMatrixPage : public IFMatrix { return 1.0f - (static_cast(nmiss)) / num_buffered_row_; } virtual void InitColAccess(const std::vector &enabled, - float pkeep = 1.0f) { + float pkeep, size_t max_row_perbatch) { if (this->HaveColAccess()) return; if (TryLoadColData()) return; - this->InitColData(enabled, pkeep); + this->InitColData(enabled, pkeep, 
max_row_perbatch); utils::Check(TryLoadColData(), "failed on creating col.blob"); } /*! @@ -282,7 +287,8 @@ class FMatrixPage : public IFMatrix { * \brief intialize column data * \param pkeep probability to keep a row */ - inline void InitColData(const std::vector &enabled, float pkeep) { + inline void InitColData(const std::vector &enabled, + float pkeep, size_t max_row_perbatch) { // clear rowset buffered_rowset_.clear(); col_size_.resize(info.num_col()); @@ -294,7 +300,7 @@ class FMatrixPage : public IFMatrix { size_t bytes_write = 0; utils::ThreadBuffer citer; citer.SetParam("buffer_size", "2"); - citer.get_factory().Setup(pkeep, info.num_col(), + citer.get_factory().Setup(pkeep, max_row_perbatch, info.num_col(), iter_, &buffered_rowset_, &enabled); citer.Init(); SparsePage *pcol; diff --git a/src/io/simple_dmatrix-inl.hpp b/src/io/simple_dmatrix-inl.hpp index 9b0addc1c..3876c21ad 100644 --- a/src/io/simple_dmatrix-inl.hpp +++ b/src/io/simple_dmatrix-inl.hpp @@ -28,7 +28,7 @@ class DMatrixSimple : public DataMatrix { public: // constructor DMatrixSimple(void) : DataMatrix(kMagic) { - fmat_ = new FMatrixS(new OneBatchIter(this)); + fmat_ = new FMatrixS(new OneBatchIter(this), this->info); this->Clear(); } // virtual destructor @@ -171,7 +171,7 @@ class DMatrixSimple : public DataMatrix { utils::Check(tmagic == kMagic, "\"%s\" invalid format, magic number mismatch", fname == NULL ? 
"" : fname); info.LoadBinary(fs); - FMatrixS::LoadBinary(fs, &row_ptr_, &row_data_); + LoadBinary(fs, &row_ptr_, &row_data_); fmat_->LoadColAccess(fs); if (!silent) { @@ -198,9 +198,8 @@ class DMatrixSimple : public DataMatrix { utils::FileStream fs(utils::FopenCheck(fname, "wb")); int tmagic = kMagic; fs.Write(&tmagic, sizeof(tmagic)); - info.SaveBinary(fs); - FMatrixS::SaveBinary(fs, row_ptr_, row_data_); + SaveBinary(fs, row_ptr_, row_data_); fmat_->SaveColAccess(fs); fs.Close(); @@ -251,6 +250,42 @@ class DMatrixSimple : public DataMatrix { static const int kMagic = 0xffffab01; protected: + /*! + * \brief save data to binary stream + * \param fo output stream + * \param ptr pointer data + * \param data data content + */ + inline static void SaveBinary(utils::IStream &fo, + const std::vector &ptr, + const std::vector &data) { + size_t nrow = ptr.size() - 1; + fo.Write(&nrow, sizeof(size_t)); + fo.Write(BeginPtr(ptr), ptr.size() * sizeof(size_t)); + if (data.size() != 0) { + fo.Write(BeginPtr(data), data.size() * sizeof(RowBatch::Entry)); + } + } + /*! 
+ * \brief load data from binary stream + * \param fi input stream + * \param out_ptr pointer data + * \param out_data data content + */ + inline static void LoadBinary(utils::IStream &fi, + std::vector *out_ptr, + std::vector *out_data) { + size_t nrow; + utils::Check(fi.Read(&nrow, sizeof(size_t)) != 0, "invalid input file format"); + out_ptr->resize(nrow + 1); + utils::Check(fi.Read(BeginPtr(*out_ptr), out_ptr->size() * sizeof(size_t)) != 0, + "invalid input file format"); + out_data->resize(out_ptr->back()); + if (out_data->size() != 0) { + utils::Assert(fi.Read(BeginPtr(*out_data), out_data->size() * sizeof(RowBatch::Entry)) != 0, + "invalid input file format"); + } + } // one batch iterator that return content in the matrix struct OneBatchIter: utils::IIterator { explicit OneBatchIter(DMatrixSimple *parent) diff --git a/src/io/simple_fmatrix-inl.hpp b/src/io/simple_fmatrix-inl.hpp index acf85297f..fc6aab8f9 100644 --- a/src/io/simple_fmatrix-inl.hpp +++ b/src/io/simple_fmatrix-inl.hpp @@ -1,15 +1,18 @@ -#ifndef XGBOOST_IO_SIMPLE_FMATRIX_INL_HPP -#define XGBOOST_IO_SIMPLE_FMATRIX_INL_HPP +#ifndef XGBOOST_IO_SIMPLE_FMATRIX_INL_HPP_ +#define XGBOOST_IO_SIMPLE_FMATRIX_INL_HPP_ /*! * \file simple_fmatrix-inl.hpp * \brief the input data structure for gradient boosting * \author Tianqi Chen */ +#include #include "../data.h" #include "../utils/utils.h" #include "../utils/random.h" #include "../utils/omp.h" +#include "../learner/dmatrix.h" #include "../utils/group_data.h" +#include "./sparse_batch_page.h" namespace xgboost { namespace io { @@ -20,21 +23,23 @@ class FMatrixS : public IFMatrix { public: typedef SparseBatch::Entry Entry; /*! \brief constructor */ - FMatrixS(utils::IIterator *iter) { + FMatrixS(utils::IIterator *iter, + const learner::MetaInfo &info) + : info_(info) { this->iter_ = iter; } // destructor virtual ~FMatrixS(void) { - if (iter_ != NULL) delete iter_; + if (iter_ != NULL) delete iter_; } /*! 
\return whether column access is enabled */ virtual bool HaveColAccess(void) const { - return col_ptr_.size() != 0; + return col_size_.size() != 0; } /*! \brief get number of colmuns */ virtual size_t NumCol(void) const { utils::Check(this->HaveColAccess(), "NumCol:need column access"); - return col_ptr_.size() - 1; + return col_size_.size() - 1; } /*! \brief get number of buffered rows */ virtual const std::vector &buffered_rowset(void) const { @@ -42,17 +47,17 @@ class FMatrixS : public IFMatrix { } /*! \brief get column size */ virtual size_t GetColSize(size_t cidx) const { - return col_ptr_[cidx+1] - col_ptr_[cidx]; + return col_size_[cidx]; } /*! \brief get column density */ virtual float GetColDensity(size_t cidx) const { - size_t nmiss = buffered_rowset_.size() - (col_ptr_[cidx+1] - col_ptr_[cidx]); + size_t nmiss = buffered_rowset_.size() - col_size_[cidx]; return 1.0f - (static_cast(nmiss)) / buffered_rowset_.size(); } virtual void InitColAccess(const std::vector &enabled, - float pkeep = 1.0f) { + float pkeep, size_t max_row_perbatch) { if (this->HaveColAccess()) return; - this->InitColData(pkeep, enabled); + this->InitColData(enabled, pkeep, max_row_perbatch); } /*! * \brief get the row iterator associated with FMatrix @@ -70,7 +75,7 @@ class FMatrixS : public IFMatrix { for (size_t i = 0; i < ncol; ++i) { col_iter_.col_index_[i] = static_cast(i); } - col_iter_.SetBatch(col_ptr_, col_data_); + col_iter_.BeforeFirst(); return &col_iter_; } /*! @@ -82,7 +87,7 @@ class FMatrixS : public IFMatrix { for (size_t i = 0; i < fset.size(); ++i) { if (fset[i] < ncol) col_iter_.col_index_.push_back(fset[i]); } - col_iter_.SetBatch(col_ptr_, col_data_); + col_iter_.BeforeFirst(); return &col_iter_; } /*! 
@@ -90,64 +95,52 @@ class FMatrixS : public IFMatrix { * \param fo output stream to save to */ inline void SaveColAccess(utils::IStream &fo) const { - fo.Write(buffered_rowset_); - if (buffered_rowset_.size() != 0) { - SaveBinary(fo, col_ptr_, col_data_); - } + size_t n = 0; + fo.Write(&n, sizeof(n)); } /*! * \brief load column access data from stream * \param fo output stream to load from */ inline void LoadColAccess(utils::IStream &fi) { - utils::Check(fi.Read(&buffered_rowset_), "invalid input file format"); - if (buffered_rowset_.size() != 0) { - LoadBinary(fi, &col_ptr_, &col_data_); - } + // do nothing in load col access } - /*! - * \brief save data to binary stream - * \param fo output stream - * \param ptr pointer data - * \param data data content - */ - inline static void SaveBinary(utils::IStream &fo, - const std::vector &ptr, - const std::vector &data) { - size_t nrow = ptr.size() - 1; - fo.Write(&nrow, sizeof(size_t)); - fo.Write(BeginPtr(ptr), ptr.size() * sizeof(size_t)); - if (data.size() != 0) { - fo.Write(BeginPtr(data), data.size() * sizeof(RowBatch::Entry)); - } - } - /*! - * \brief load data from binary stream - * \param fi input stream - * \param out_ptr pointer data - * \param out_data data content - */ - inline static void LoadBinary(utils::IStream &fi, - std::vector *out_ptr, - std::vector *out_data) { - size_t nrow; - utils::Check(fi.Read(&nrow, sizeof(size_t)) != 0, "invalid input file format"); - out_ptr->resize(nrow + 1); - utils::Check(fi.Read(BeginPtr(*out_ptr), out_ptr->size() * sizeof(size_t)) != 0, - "invalid input file format"); - out_data->resize(out_ptr->back()); - if (out_data->size() != 0) { - utils::Assert(fi.Read(BeginPtr(*out_data), out_data->size() * sizeof(RowBatch::Entry)) != 0, - "invalid input file format"); - } - } - + protected: /*! 
* \brief intialize column data + * \param enabled the list of enabled columns * \param pkeep probability to keep a row + * \param max_row_perbatch maximum row per batch */ - inline void InitColData(float pkeep, const std::vector &enabled) { + inline void InitColData(const std::vector &enabled, + float pkeep, size_t max_row_perbatch) { + col_iter_.Clear(); + if (info_.num_row() < max_row_perbatch) { + SparsePage *page = new SparsePage(); + this->MakeOneBatch(enabled, pkeep, page); + col_iter_.cpages_.push_back(page); + } else { + this->MakeManyBatch(enabled, pkeep, max_row_perbatch); + } + // setup col-size + col_size_.resize(info_.num_col()); + std::fill(col_size_.begin(), col_size_.end(), 0); + for (size_t i = 0; i < col_iter_.cpages_.size(); ++i) { + SparsePage *pcol = col_iter_.cpages_[i]; + for (size_t j = 0; j < pcol->Size(); ++j) { + col_size_[j] += pcol->offset[j + 1] - pcol->offset[j]; + } + } + } + /*! + * \brief make column page from iterator + * \param pkeep probability to keep a row + * \param pcol the target column + */ + inline void MakeOneBatch(const std::vector &enabled, + float pkeep, + SparsePage *pcol) { // clear rowset buffered_rowset_.clear(); // bit map @@ -157,8 +150,9 @@ class FMatrixS : public IFMatrix { { nthread = omp_get_num_threads(); } - // build the column matrix in parallel - utils::ParallelGroupBuilder builder(&col_ptr_, &col_data_); + pcol->Clear(); + utils::ParallelGroupBuilder + builder(&pcol->offset, &pcol->data); builder.InitBudget(0, nthread); // start working iter_->BeforeFirst(); @@ -189,7 +183,7 @@ class FMatrixS : public IFMatrix { } } builder.InitStorage(); - + iter_->BeforeFirst(); while (iter_->Next()) { const RowBatch &batch = iter_->Value(); @@ -209,66 +203,167 @@ class FMatrixS : public IFMatrix { } } } + + utils::Assert(pcol->Size() == info_.num_col(), "inconsistent col data"); // sort columns - bst_omp_uint ncol = static_cast(this->NumCol()); - #pragma omp parallel for schedule(static) + bst_omp_uint ncol = 
static_cast(pcol->Size()); + #pragma omp parallel for schedule(dynamic, 1) num_threads(nthread) for (bst_omp_uint i = 0; i < ncol; ++i) { - if (col_ptr_[i] < col_ptr_[i + 1]) { - std::sort(BeginPtr(col_data_) + col_ptr_[i], - BeginPtr(col_data_) + col_ptr_[i + 1], Entry::CmpValue); + if (pcol->offset[i] < pcol->offset[i + 1]) { + std::sort(BeginPtr(pcol->data) + pcol->offset[i], + BeginPtr(pcol->data) + pcol->offset[i + 1], + SparseBatch::Entry::CmpValue); + } + } + } + + inline void MakeManyBatch(const std::vector &enabled, + float pkeep, size_t max_row_perbatch) { + size_t btop = 0; + buffered_rowset_.clear(); + // internal temp cache + SparsePage tmp; tmp.Clear(); + iter_->BeforeFirst(); + while (iter_->Next()) { + const RowBatch &batch = iter_->Value(); + for (size_t i = 0; i < batch.size; ++i) { + bst_uint ridx = static_cast(batch.base_rowid + i); + if (pkeep == 1.0f || random::SampleBinary(pkeep)) { + buffered_rowset_.push_back(ridx); + tmp.Push(batch[i]); + } + if (tmp.Size() >= max_row_perbatch) { + SparsePage *page = new SparsePage(); + this->MakeColPage(tmp.GetRowBatch(0), + BeginPtr(buffered_rowset_) + btop, + enabled, page); + col_iter_.cpages_.push_back(page); + btop = buffered_rowset_.size(); + tmp.Clear(); + } + } + } + if (tmp.Size() != 0) { + SparsePage *page = new SparsePage(); + this->MakeColPage(tmp.GetRowBatch(0), + BeginPtr(buffered_rowset_) + btop, + enabled, page); + col_iter_.cpages_.push_back(page); + } + } + // make column page from subset of rowbatchs + inline void MakeColPage(const RowBatch &batch, + const bst_uint *ridx, + const std::vector &enabled, + SparsePage *pcol) { + int nthread; + #pragma omp parallel + { + nthread = omp_get_num_threads(); + int max_nthread = std::max(omp_get_num_procs() / 2 - 2, 1); + if (nthread > max_nthread) { + nthread = max_nthread; + } + } + pcol->Clear(); + utils::ParallelGroupBuilder + builder(&pcol->offset, &pcol->data); + builder.InitBudget(info_.num_col(), nthread); + bst_omp_uint ndata = 
static_cast(batch.size); + #pragma omp parallel for schedule(static) num_threads(nthread) + for (bst_omp_uint i = 0; i < ndata; ++i) { + int tid = omp_get_thread_num(); + RowBatch::Inst inst = batch[i]; + for (bst_uint j = 0; j < inst.length; ++j) { + const SparseBatch::Entry &e = inst[j]; + if (enabled[e.index]) { + builder.AddBudget(e.index, tid); + } + } + } + builder.InitStorage(); + #pragma omp parallel for schedule(static) num_threads(nthread) + for (bst_omp_uint i = 0; i < ndata; ++i) { + int tid = omp_get_thread_num(); + RowBatch::Inst inst = batch[i]; + for (bst_uint j = 0; j < inst.length; ++j) { + const SparseBatch::Entry &e = inst[j]; + builder.Push(e.index, + SparseBatch::Entry(ridx[i], e.fvalue), + tid); + } + } + utils::Assert(pcol->Size() == info_.num_col(), "inconsistent col data"); + // sort columns + bst_omp_uint ncol = static_cast(pcol->Size()); + #pragma omp parallel for schedule(dynamic, 1) num_threads(nthread) + for (bst_omp_uint i = 0; i < ncol; ++i) { + if (pcol->offset[i] < pcol->offset[i + 1]) { + std::sort(BeginPtr(pcol->data) + pcol->offset[i], + BeginPtr(pcol->data) + pcol->offset[i + 1], + SparseBatch::Entry::CmpValue); } } } private: // one batch iterator that return content in the matrix - struct OneBatchIter: utils::IIterator { - OneBatchIter(void) : at_first_(true){} - virtual ~OneBatchIter(void) {} + struct ColBatchIter: utils::IIterator { + ColBatchIter(void) : data_ptr_(0) {} + virtual ~ColBatchIter(void) { + this->Clear(); + } virtual void BeforeFirst(void) { - at_first_ = true; + data_ptr_ = 0; } virtual bool Next(void) { - if (!at_first_) return false; - at_first_ = false; - return true; - } - virtual const ColBatch &Value(void) const { - return batch_; - } - inline void SetBatch(const std::vector &ptr, - const std::vector &data) { + if (data_ptr_ >= cpages_.size()) return false; + data_ptr_ += 1; + SparsePage *pcol = cpages_[data_ptr_ - 1]; batch_.size = col_index_.size(); col_data_.resize(col_index_.size(), 
SparseBatch::Inst(NULL, 0)); for (size_t i = 0; i < col_data_.size(); ++i) { const bst_uint ridx = col_index_[i]; - col_data_[i] = SparseBatch::Inst(&data[0] + ptr[ridx], - static_cast(ptr[ridx+1] - ptr[ridx])); + col_data_[i] = SparseBatch::Inst + (BeginPtr(pcol->data) + pcol->offset[ridx], + static_cast(pcol->offset[ridx + 1] - pcol->offset[ridx])); } batch_.col_index = BeginPtr(col_index_); - batch_.col_data = BeginPtr(col_data_); - this->BeforeFirst(); + batch_.col_data = BeginPtr(col_data_); + return true; + } + virtual const ColBatch &Value(void) const { + return batch_; + } + inline void Clear(void) { + for (size_t i = 0; i < cpages_.size(); ++i) { + delete cpages_[i]; + } + cpages_.clear(); } // data content std::vector col_index_; + // column content std::vector col_data_; - // whether is at first - bool at_first_; + // column sparse pages + std::vector cpages_; + // data pointer + size_t data_ptr_; // temporal space for batch ColBatch batch_; - }; + }; // --- data structure used to support InitColAccess -- // column iterator - OneBatchIter col_iter_; + ColBatchIter col_iter_; + // shared meta info with DMatrix + const learner::MetaInfo &info_; // row iterator utils::IIterator *iter_; /*! \brief list of row index that are buffered */ std::vector buffered_rowset_; - /*! \brief column pointer of CSC format */ - std::vector col_ptr_; - /*! 
\brief column datas in CSC format */ - std::vector col_data_; + // count for column data + std::vector col_size_; }; } // namespace io } // namespace xgboost -#endif // XGBOOST_IO_SIMPLE_FMATRIX_INL_HPP +#endif // XGBOOST_IO_SIMPLE_FMATRIX_INL_HPP_ diff --git a/src/io/sparse_batch_page.h b/src/io/sparse_batch_page.h index 319f9da5c..d94141a6e 100644 --- a/src/io/sparse_batch_page.h +++ b/src/io/sparse_batch_page.h @@ -178,8 +178,22 @@ class SparsePage { offset.push_back(offset.back() + inst.length); size_t begin = data.size(); data.resize(begin + inst.length); - std::memcpy(BeginPtr(data) + begin, inst.data, - sizeof(SparseBatch::Entry) * inst.length); + if (inst.length != 0) { + std::memcpy(BeginPtr(data) + begin, inst.data, + sizeof(SparseBatch::Entry) * inst.length); + } + } + /*! + * \param base_rowid base_rowid of the data + * \return row batch representation of the page + */ + inline RowBatch GetRowBatch(size_t base_rowid) const { + RowBatch out; + out.base_rowid = base_rowid; + out.ind_ptr = BeginPtr(offset); + out.data_ptr = BeginPtr(data); + out.size = offset.size() - 1; + return out; } private: diff --git a/src/learner/learner-inl.hpp b/src/learner/learner-inl.hpp index 5a080d5b1..45e312aa7 100644 --- a/src/learner/learner-inl.hpp +++ b/src/learner/learner-inl.hpp @@ -33,6 +33,7 @@ class BoostLearner : public rabit::Serializable { silent= 0; prob_buffer_row = 1.0f; distributed_mode = 0; + updater_mode = 0; pred_buffer_size = 0; seed_per_iteration = 0; seed = 0; @@ -95,6 +96,7 @@ class BoostLearner : public rabit::Serializable { utils::Error("%s is invalid value for dsplit, should be row or col", val); } } + if (!strcmp(name, "updater_mode")) updater_mode = atoi(val); if (!strcmp(name, "prob_buffer_row")) { prob_buffer_row = static_cast(atof(val)); utils::Check(distributed_mode == 0, @@ -259,9 +261,17 @@ class BoostLearner : public rabit::Serializable { */ inline void CheckInit(DMatrix *p_train) { int ncol = static_cast(p_train->info.info.num_col); - 
std::vector enabled(ncol, true); + std::vector enabled(ncol, true); + // set max row per batch to limited value + // in distributed mode, use safe choice otherwise + size_t max_row_perbatch = std::numeric_limits::max(); + if (updater_mode != 0 || distributed_mode == 2) { + max_row_perbatch = 32UL << 10UL; + } // initialize column access - p_train->fmat()->InitColAccess(enabled, prob_buffer_row); + p_train->fmat()->InitColAccess(enabled, + prob_buffer_row, + max_row_perbatch); const int kMagicPage = 0xffffab02; // check, if it is DMatrixPage, then use hist maker if (p_train->magic == kMagicPage) { @@ -480,6 +490,8 @@ class BoostLearner : public rabit::Serializable { int silent; // distributed learning mode, if any, 0:none, 1:col, 2:row int distributed_mode; + // updater mode, 0:normal, reserved for internal test + int updater_mode; // cached size of predict buffer size_t pred_buffer_size; // maximum buffred row value