Refactor linear modelling and add new coordinate descent updater (#3103)
* Refactor linear modelling and add new coordinate descent updater * Allow unsorted column iterator * Add prediction caching to gblinear
This commit is contained in:
@@ -54,16 +54,16 @@ dmlc::DataIter<ColBatch>* SimpleDMatrix::ColIterator(const std::vector<bst_uint>
|
||||
|
||||
void SimpleDMatrix::InitColAccess(const std::vector<bool> &enabled,
|
||||
float pkeep,
|
||||
size_t max_row_perbatch) {
|
||||
if (this->HaveColAccess()) return;
|
||||
|
||||
size_t max_row_perbatch, bool sorted) {
|
||||
if (this->HaveColAccess(sorted)) return;
|
||||
col_iter_.sorted = sorted;
|
||||
col_iter_.cpages_.clear();
|
||||
if (info().num_row < max_row_perbatch) {
|
||||
std::unique_ptr<SparsePage> page(new SparsePage());
|
||||
this->MakeOneBatch(enabled, pkeep, page.get());
|
||||
this->MakeOneBatch(enabled, pkeep, page.get(), sorted);
|
||||
col_iter_.cpages_.push_back(std::move(page));
|
||||
} else {
|
||||
this->MakeManyBatch(enabled, pkeep, max_row_perbatch);
|
||||
this->MakeManyBatch(enabled, pkeep, max_row_perbatch, sorted);
|
||||
}
|
||||
// setup col-size
|
||||
col_size_.resize(info().num_col);
|
||||
@@ -77,9 +77,8 @@ void SimpleDMatrix::InitColAccess(const std::vector<bool> &enabled,
|
||||
}
|
||||
|
||||
// internal function to make one batch from row iter.
|
||||
void SimpleDMatrix::MakeOneBatch(const std::vector<bool>& enabled,
|
||||
float pkeep,
|
||||
SparsePage *pcol) {
|
||||
void SimpleDMatrix::MakeOneBatch(const std::vector<bool>& enabled, float pkeep,
|
||||
SparsePage* pcol, bool sorted) {
|
||||
// clear rowset
|
||||
buffered_rowset_.clear();
|
||||
// bit map
|
||||
@@ -144,21 +143,24 @@ void SimpleDMatrix::MakeOneBatch(const std::vector<bool>& enabled,
|
||||
}
|
||||
|
||||
CHECK_EQ(pcol->Size(), info().num_col);
|
||||
// sort columns
|
||||
bst_omp_uint ncol = static_cast<bst_omp_uint>(pcol->Size());
|
||||
#pragma omp parallel for schedule(dynamic, 1) num_threads(nthread)
|
||||
for (bst_omp_uint i = 0; i < ncol; ++i) {
|
||||
if (pcol->offset[i] < pcol->offset[i + 1]) {
|
||||
std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i],
|
||||
dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1],
|
||||
SparseBatch::Entry::CmpValue);
|
||||
|
||||
if (sorted) {
|
||||
// sort columns
|
||||
bst_omp_uint ncol = static_cast<bst_omp_uint>(pcol->Size());
|
||||
#pragma omp parallel for schedule(dynamic, 1) num_threads(nthread)
|
||||
for (bst_omp_uint i = 0; i < ncol; ++i) {
|
||||
if (pcol->offset[i] < pcol->offset[i + 1]) {
|
||||
std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i],
|
||||
dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1],
|
||||
SparseBatch::Entry::CmpValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SimpleDMatrix::MakeManyBatch(const std::vector<bool>& enabled,
|
||||
float pkeep,
|
||||
size_t max_row_perbatch) {
|
||||
size_t max_row_perbatch, bool sorted) {
|
||||
size_t btop = 0;
|
||||
std::bernoulli_distribution coin_flip(pkeep);
|
||||
auto& rnd = common::GlobalRandom();
|
||||
@@ -179,7 +181,7 @@ void SimpleDMatrix::MakeManyBatch(const std::vector<bool>& enabled,
|
||||
}
|
||||
if (tmp.Size() >= max_row_perbatch) {
|
||||
std::unique_ptr<SparsePage> page(new SparsePage());
|
||||
this->MakeColPage(tmp.GetRowBatch(0), btop, enabled, page.get());
|
||||
this->MakeColPage(tmp.GetRowBatch(0), btop, enabled, page.get(), sorted);
|
||||
col_iter_.cpages_.push_back(std::move(page));
|
||||
btop = buffered_rowset_.size();
|
||||
tmp.Clear();
|
||||
@@ -189,7 +191,7 @@ void SimpleDMatrix::MakeManyBatch(const std::vector<bool>& enabled,
|
||||
|
||||
if (tmp.Size() != 0) {
|
||||
std::unique_ptr<SparsePage> page(new SparsePage());
|
||||
this->MakeColPage(tmp.GetRowBatch(0), btop, enabled, page.get());
|
||||
this->MakeColPage(tmp.GetRowBatch(0), btop, enabled, page.get(), sorted);
|
||||
col_iter_.cpages_.push_back(std::move(page));
|
||||
}
|
||||
}
|
||||
@@ -198,7 +200,7 @@ void SimpleDMatrix::MakeManyBatch(const std::vector<bool>& enabled,
|
||||
void SimpleDMatrix::MakeColPage(const RowBatch& batch,
|
||||
size_t buffer_begin,
|
||||
const std::vector<bool>& enabled,
|
||||
SparsePage* pcol) {
|
||||
SparsePage* pcol, bool sorted) {
|
||||
const int nthread = std::min(omp_get_max_threads(), std::max(omp_get_num_procs() / 2 - 2, 1));
|
||||
pcol->Clear();
|
||||
common::ParallelGroupBuilder<SparseBatch::Entry>
|
||||
@@ -231,13 +233,15 @@ void SimpleDMatrix::MakeColPage(const RowBatch& batch,
|
||||
}
|
||||
CHECK_EQ(pcol->Size(), info().num_col);
|
||||
// sort columns
|
||||
bst_omp_uint ncol = static_cast<bst_omp_uint>(pcol->Size());
|
||||
#pragma omp parallel for schedule(dynamic, 1) num_threads(nthread)
|
||||
for (bst_omp_uint i = 0; i < ncol; ++i) {
|
||||
if (pcol->offset[i] < pcol->offset[i + 1]) {
|
||||
std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i],
|
||||
dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1],
|
||||
SparseBatch::Entry::CmpValue);
|
||||
if (sorted) {
|
||||
bst_omp_uint ncol = static_cast<bst_omp_uint>(pcol->Size());
|
||||
#pragma omp parallel for schedule(dynamic, 1) num_threads(nthread)
|
||||
for (bst_omp_uint i = 0; i < ncol; ++i) {
|
||||
if (pcol->offset[i] < pcol->offset[i + 1]) {
|
||||
std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i],
|
||||
dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1],
|
||||
SparseBatch::Entry::CmpValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,8 +36,8 @@ class SimpleDMatrix : public DMatrix {
|
||||
return iter;
|
||||
}
|
||||
|
||||
bool HaveColAccess() const override {
|
||||
return col_size_.size() != 0;
|
||||
bool HaveColAccess(bool sorted) const override {
|
||||
return col_size_.size() != 0 && col_iter_.sorted == sorted;
|
||||
}
|
||||
|
||||
const RowSet& buffered_rowset() const override {
|
||||
@@ -59,7 +59,7 @@ class SimpleDMatrix : public DMatrix {
|
||||
|
||||
void InitColAccess(const std::vector<bool>& enabled,
|
||||
float subsample,
|
||||
size_t max_row_perbatch) override;
|
||||
size_t max_row_perbatch, bool sorted) override;
|
||||
|
||||
bool SingleColBlock() const override;
|
||||
|
||||
@@ -67,7 +67,7 @@ class SimpleDMatrix : public DMatrix {
|
||||
// in-memory column batch iterator.
|
||||
struct ColBatchIter: dmlc::DataIter<ColBatch> {
|
||||
public:
|
||||
ColBatchIter() : data_ptr_(0) {}
|
||||
ColBatchIter() : data_ptr_(0), sorted(false) {}
|
||||
void BeforeFirst() override {
|
||||
data_ptr_ = 0;
|
||||
}
|
||||
@@ -89,6 +89,8 @@ class SimpleDMatrix : public DMatrix {
|
||||
size_t data_ptr_;
|
||||
// temporary space for batch
|
||||
ColBatch batch_;
|
||||
// Is column sorted?
|
||||
bool sorted;
|
||||
};
|
||||
|
||||
// source data pointer.
|
||||
@@ -103,16 +105,16 @@ class SimpleDMatrix : public DMatrix {
|
||||
// internal function to make one batch from row iter.
|
||||
void MakeOneBatch(const std::vector<bool>& enabled,
|
||||
float pkeep,
|
||||
SparsePage *pcol);
|
||||
SparsePage *pcol, bool sorted);
|
||||
|
||||
void MakeManyBatch(const std::vector<bool>& enabled,
|
||||
float pkeep,
|
||||
size_t max_row_perbatch);
|
||||
size_t max_row_perbatch, bool sorted);
|
||||
|
||||
void MakeColPage(const RowBatch& batch,
|
||||
size_t buffer_begin,
|
||||
const std::vector<bool>& enabled,
|
||||
SparsePage* pcol);
|
||||
SparsePage* pcol, bool sorted);
|
||||
};
|
||||
} // namespace data
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -145,10 +145,9 @@ bool SparsePageDMatrix::TryInitColData() {
|
||||
|
||||
void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
|
||||
float pkeep,
|
||||
size_t max_row_perbatch) {
|
||||
if (HaveColAccess()) return;
|
||||
size_t max_row_perbatch, bool sorted) {
|
||||
if (HaveColAccess(sorted)) return;
|
||||
if (TryInitColData()) return;
|
||||
|
||||
const MetaInfo& info = this->info();
|
||||
if (max_row_perbatch == std::numeric_limits<size_t>::max()) {
|
||||
max_row_perbatch = kMaxRowPerBatch;
|
||||
@@ -197,13 +196,15 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
|
||||
}
|
||||
CHECK_EQ(pcol->Size(), info.num_col);
|
||||
// sort columns
|
||||
bst_omp_uint ncol = static_cast<bst_omp_uint>(pcol->Size());
|
||||
#pragma omp parallel for schedule(dynamic, 1) num_threads(nthread)
|
||||
for (bst_omp_uint i = 0; i < ncol; ++i) {
|
||||
if (pcol->offset[i] < pcol->offset[i + 1]) {
|
||||
std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i],
|
||||
dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1],
|
||||
SparseBatch::Entry::CmpValue);
|
||||
if (sorted) {
|
||||
bst_omp_uint ncol = static_cast<bst_omp_uint>(pcol->Size());
|
||||
#pragma omp parallel for schedule(dynamic, 1) num_threads(nthread)
|
||||
for (bst_omp_uint i = 0; i < ncol; ++i) {
|
||||
if (pcol->offset[i] < pcol->offset[i + 1]) {
|
||||
std::sort(dmlc::BeginPtr(pcol->data) + pcol->offset[i],
|
||||
dmlc::BeginPtr(pcol->data) + pcol->offset[i + 1],
|
||||
SparseBatch::Entry::CmpValue);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -291,6 +292,7 @@ void SparsePageDMatrix::InitColAccess(const std::vector<bool>& enabled,
|
||||
}
|
||||
// initialize column data
|
||||
CHECK(TryInitColData());
|
||||
col_iter_->sorted = sorted;
|
||||
}
|
||||
|
||||
} // namespace data
|
||||
|
||||
@@ -40,8 +40,8 @@ class SparsePageDMatrix : public DMatrix {
|
||||
return iter;
|
||||
}
|
||||
|
||||
bool HaveColAccess() const override {
|
||||
return col_iter_.get() != nullptr;
|
||||
bool HaveColAccess(bool sorted) const override {
|
||||
return col_iter_.get() != nullptr && col_iter_->sorted == sorted;
|
||||
}
|
||||
|
||||
const RowSet& buffered_rowset() const override {
|
||||
@@ -67,7 +67,7 @@ class SparsePageDMatrix : public DMatrix {
|
||||
|
||||
void InitColAccess(const std::vector<bool>& enabled,
|
||||
float subsample,
|
||||
size_t max_row_perbatch) override;
|
||||
size_t max_row_perbatch, bool sorted) override;
|
||||
|
||||
/*! \brief page size 256 MB */
|
||||
static const size_t kPageSize = 256UL << 20UL;
|
||||
@@ -87,6 +87,8 @@ class SparsePageDMatrix : public DMatrix {
|
||||
bool Next() override;
|
||||
// initialize the column iterator with the specified index set.
|
||||
void Init(const std::vector<bst_uint>& index_set, bool load_all);
|
||||
// Whether the column features are sorted
|
||||
bool sorted;
|
||||
|
||||
private:
|
||||
// the temp page.
|
||||
|
||||
Reference in New Issue
Block a user