add part_load col

This commit is contained in:
tqchen
2014-10-16 19:41:43 -07:00
parent f512f08437
commit 3f3c90c3c0
8 changed files with 66 additions and 12 deletions

View File

@@ -247,7 +247,7 @@ class FMatrixPage : public IFMatrix {
size_t nmiss = buffered_rowset_.size() - (col_ptr[cidx+1] - col_ptr[cidx]);
return 1.0f - (static_cast<float>(nmiss)) / buffered_rowset_.size();
}
virtual void InitColAccess(float pkeep = 1.0f) {
virtual void InitColAccess(const std::vector<bool> &enabled, float pkeep = 1.0f) {
if (this->HaveColAccess()) return;
utils::Printf("start to initialize page col access\n");
if (this->LoadColData()) {

View File

@@ -48,9 +48,10 @@ class FMatrixS : public IFMatrix{
size_t nmiss = buffered_rowset_.size() - (col_ptr_[cidx+1] - col_ptr_[cidx]);
return 1.0f - (static_cast<float>(nmiss)) / buffered_rowset_.size();
}
virtual void InitColAccess(float pkeep = 1.0f) {
virtual void InitColAccess(const std::vector<bool> &enabled,
float pkeep = 1.0f) {
if (this->HaveColAccess()) return;
this->InitColData(pkeep);
this->InitColData(pkeep, enabled);
}
/*!
* \brief get the row iterator associated with FMatrix
@@ -141,7 +142,7 @@ class FMatrixS : public IFMatrix{
* \brief initialize column data
* \param pkeep probability to keep a row
*/
inline void InitColData(float pkeep) {
inline void InitColData(float pkeep, const std::vector<bool> &enabled) {
buffered_rowset_.clear();
// note: this part of the code is serial; TODO: parallelize this transformer
utils::SparseCSRMBuilder<RowBatch::Entry> builder(col_ptr_, col_data_);
@@ -155,7 +156,9 @@ class FMatrixS : public IFMatrix{
buffered_rowset_.push_back(static_cast<bst_uint>(batch.base_rowid+i));
RowBatch::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
builder.AddBudget(inst[j].index);
if (enabled[inst[j].index]){
builder.AddBudget(inst[j].index);
}
}
}
}
@@ -172,9 +175,11 @@ class FMatrixS : public IFMatrix{
++ktop;
RowBatch::Inst inst = batch[i];
for (bst_uint j = 0; j < inst.length; ++j) {
builder.PushElem(inst[j].index,
Entry((bst_uint)(batch.base_rowid+i),
inst[j].fvalue));
if (enabled[inst[j].index]) {
builder.PushElem(inst[j].index,
Entry((bst_uint)(batch.base_rowid+i),
inst[j].fvalue));
}
}
}
}