add part_load_col

This commit is contained in:
tqchen 2014-10-16 19:41:43 -07:00
parent f512f08437
commit 3f3c90c3c0
8 changed files with 66 additions and 12 deletions
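In short: this commit adds a part_load_col parameter to BoostLearner. When it is nonzero, rank 0 shuffles the column indices, broadcasts the permutation, and each worker enables only its own contiguous slice of the shuffled columns; the new enabled mask is then passed through IFMatrix::InitColAccess so the feature matrices only build column data for the enabled features. A rough usage sketch (assumptions: the xgboost::learner namespace and an already-loaded DMatrix *p_train; only the parameter name part_load_col comes from this diff):

// hypothetical driver fragment, not part of the commit
xgboost::learner::BoostLearner learner;
learner.SetParam("part_load_col", "1");      // nonzero: split columns across ranks
learner.SetParam("prob_buffer_row", "1.0");  // row-buffering ratio, unchanged behaviour
learner.CheckInit(p_train);                  // builds the per-rank enabled mask, then calls
                                             // p_train->fmat()->InitColAccess(enabled, prob_buffer_row)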

View File

@@ -138,9 +138,10 @@ class IFMatrix {
   virtual utils::IIterator<ColBatch> *ColIterator(const std::vector<bst_uint> &fset) = 0;
   /*!
    * \brief check if column access is supported, if not, initialize column access
+   * \param enabled whether certain feature should be included in column access
    * \param subsample subsample ratio when generating column access
    */
-  virtual void InitColAccess(float subsample) = 0;
+  virtual void InitColAccess(const std::vector<bool> &enabled, float subsample) = 0;
   // the following are column meta data, should be able to answer them fast
   /*! \return whether column access is enabled */
   virtual bool HaveColAccess(void) const = 0;

View File

@@ -247,7 +247,7 @@ class FMatrixPage : public IFMatrix {
     size_t nmiss = buffered_rowset_.size() - (col_ptr[cidx+1] - col_ptr[cidx]);
     return 1.0f - (static_cast<float>(nmiss)) / buffered_rowset_.size();
   }
-  virtual void InitColAccess(float pkeep = 1.0f) {
+  virtual void InitColAccess(const std::vector<bool> &enabled, float pkeep = 1.0f) {
     if (this->HaveColAccess()) return;
     utils::Printf("start to initialize page col access\n");
     if (this->LoadColData()) {

View File

@@ -48,9 +48,10 @@ class FMatrixS : public IFMatrix{
     size_t nmiss = buffered_rowset_.size() - (col_ptr_[cidx+1] - col_ptr_[cidx]);
     return 1.0f - (static_cast<float>(nmiss)) / buffered_rowset_.size();
   }
-  virtual void InitColAccess(float pkeep = 1.0f) {
+  virtual void InitColAccess(const std::vector<bool> &enabled,
+                             float pkeep = 1.0f) {
     if (this->HaveColAccess()) return;
-    this->InitColData(pkeep);
+    this->InitColData(pkeep, enabled);
   }
   /*!
    * \brief get the row iterator associated with FMatrix
@@ -141,7 +142,7 @@ class FMatrixS : public IFMatrix{
    * \brief intialize column data
    * \param pkeep probability to keep a row
    */
-  inline void InitColData(float pkeep) {
+  inline void InitColData(float pkeep, const std::vector<bool> &enabled) {
     buffered_rowset_.clear();
     // note: this part of code is serial, todo, parallelize this transformer
     utils::SparseCSRMBuilder<RowBatch::Entry> builder(col_ptr_, col_data_);
@@ -155,7 +156,9 @@ class FMatrixS : public IFMatrix{
         buffered_rowset_.push_back(static_cast<bst_uint>(batch.base_rowid+i));
         RowBatch::Inst inst = batch[i];
         for (bst_uint j = 0; j < inst.length; ++j) {
-          builder.AddBudget(inst[j].index);
+          if (enabled[inst[j].index]) {
+            builder.AddBudget(inst[j].index);
+          }
         }
       }
     }
@@ -172,9 +175,11 @@ class FMatrixS : public IFMatrix{
         ++ktop;
         RowBatch::Inst inst = batch[i];
         for (bst_uint j = 0; j < inst.length; ++j) {
-          builder.PushElem(inst[j].index,
-                           Entry((bst_uint)(batch.base_rowid+i),
-                                 inst[j].fvalue));
+          if (enabled[inst[j].index]) {
+            builder.PushElem(inst[j].index,
+                             Entry((bst_uint)(batch.base_rowid+i),
+                                   inst[j].fvalue));
+          }
         }
       }
     }
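The two hunks above are the two passes of the CSR transpose in InitColData: a budget pass that counts entries per column (AddBudget) and a fill pass that places them (PushElem), both now skipping columns whose enabled flag is false. A minimal self-contained sketch of that filtered two-pass construction, using plain std::vector instead of utils::SparseCSRMBuilder (all names here are illustrative, not xgboost code):

#include <cstdio>
#include <utility>
#include <vector>

struct Entry { unsigned rindex; float fvalue; };

int main() {
  // toy row-major input: (column index, value) pairs per row
  std::vector<std::vector<std::pair<unsigned, float> > > rows = {
    {{0, 1.0f}, {2, 3.0f}},
    {{1, 2.0f}, {2, 5.0f}}
  };
  std::vector<bool> enabled = {true, false, true};   // column 1 is disabled
  // pass 1 (budget): count kept entries per column, like builder.AddBudget
  std::vector<size_t> col_ptr(enabled.size() + 1, 0);
  for (size_t i = 0; i < rows.size(); ++i)
    for (size_t j = 0; j < rows[i].size(); ++j)
      if (enabled[rows[i][j].first]) ++col_ptr[rows[i][j].first + 1];
  for (size_t c = 1; c < col_ptr.size(); ++c) col_ptr[c] += col_ptr[c - 1];
  // pass 2 (fill): place each kept entry into its column segment, like builder.PushElem
  std::vector<Entry> col_data(col_ptr.back());
  std::vector<size_t> top(col_ptr.begin(), col_ptr.end() - 1);
  for (size_t i = 0; i < rows.size(); ++i)
    for (size_t j = 0; j < rows[i].size(); ++j)
      if (enabled[rows[i][j].first]) {
        Entry e = {static_cast<unsigned>(i), rows[i][j].second};
        col_data[top[rows[i][j].first]++] = e;
      }
  for (size_t c = 0; c < enabled.size(); ++c)
    std::printf("col %zu: %zu entries\n", c, col_ptr[c + 1] - col_ptr[c]);
  return 0;
}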

View File

@@ -31,6 +31,7 @@ class BoostLearner {
     name_gbm_ = "gbtree";
     silent= 0;
     prob_buffer_row = 1.0f;
+    part_load_col = 0;
   }
   ~BoostLearner(void) {
     if (obj_ != NULL) delete obj_;
@@ -88,6 +89,7 @@ class BoostLearner {
       this->SetParam(n.c_str(), val);
     }
     if (!strcmp(name, "silent")) silent = atoi(val);
+    if (!strcmp(name, "part_load_col")) part_load_col = atoi(val);
     if (!strcmp(name, "prob_buffer_row")) {
       prob_buffer_row = static_cast<float>(atof(val));
       this->SetParam("updater", "grow_colmaker,refresh,prune");
@@ -164,8 +166,41 @@ class BoostLearner {
    * if not intialize it
    * \param p_train pointer to the matrix used by training
    */
   inline void CheckInit(DMatrix *p_train) {
-    p_train->fmat()->InitColAccess(prob_buffer_row);
+    int ncol = p_train->info.info.num_col;
+    std::vector<bool> enabled(ncol, true);
+    if (part_load_col != 0) {
+      std::vector<unsigned> col_index;
+      for (int i = 0; i < ncol; ++i) {
+        col_index.push_back(i);
+      }
+      random::Shuffle(col_index);
+      std::string s_model;
+      utils::MemoryBufferStream ms(&s_model);
+      utils::IStream &fs = ms;
+      if (sync::GetRank() == 0) {
+        fs.Write(col_index);
+        sync::Bcast(&s_model, 0);
+      } else {
+        sync::Bcast(&s_model, 0);
+        fs.Read(&col_index);
+      }
+      int nsize = sync::GetWorldSize();
+      int step = (ncol + nsize - 1) / nsize;
+      int pid = sync::GetRank();
+      std::fill(enabled.begin(), enabled.end(), false);
+      int start = step * pid;
+      int end = std::min(step * (pid + 1), ncol);
+      utils::Printf("rank %d idset:", pid);
+      for (int i = start; i < end; ++i) {
+        enabled[col_index[i]] = true;
+        utils::Printf(" %u", col_index[i]);
+      }
+      utils::Printf("\n");
+    }
+    // initialize column access
+    p_train->fmat()->InitColAccess(enabled, prob_buffer_row);
   }
   /*!
    * \brief update the model for one iteration
@@ -316,6 +351,8 @@ class BoostLearner {
   // data fields
   // silent during training
   int silent;
+  // randomly load part of data
+  int part_load_col;
   // maximum buffred row value
   float prob_buffer_row;
   // evaluation set
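The CheckInit hunk above is where the part_load_col behaviour lives: build a column-index permutation, make it identical on every worker by broadcasting rank 0's shuffle, and let rank p enable the slice [step*p, step*(p+1)) of the permutation with step = ceil(ncol/nsize). A standalone sketch of just that partitioning (a fixed std::mt19937 seed stands in for the sync::Bcast of rank 0's permutation; all "ranks" run in one process here purely for illustration):

#include <algorithm>
#include <cstdio>
#include <random>
#include <vector>

int main() {
  const int ncol = 10, nsize = 3;               // number of columns and workers
  std::vector<unsigned> col_index(ncol);
  for (int i = 0; i < ncol; ++i) col_index[i] = static_cast<unsigned>(i);
  // every worker must see the same permutation; a shared seed emulates
  // broadcasting rank 0's shuffle, as the commit does through sync::Bcast
  std::shuffle(col_index.begin(), col_index.end(), std::mt19937(42));
  const int step = (ncol + nsize - 1) / nsize;  // ceil(ncol / nsize)
  for (int pid = 0; pid < nsize; ++pid) {       // one iteration per "rank"
    std::vector<bool> enabled(ncol, false);
    const int start = step * pid;
    const int end = std::min(step * (pid + 1), ncol);
    std::printf("rank %d idset:", pid);
    for (int i = start; i < end; ++i) {
      enabled[col_index[i]] = true;             // this rank keeps these columns
      std::printf(" %u", col_index[i]);
    }
    std::printf("\n");
  }
  return 0;
}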

View File

@@ -21,6 +21,9 @@ enum ReduceOp {
 /*! \brief get rank of current process */
 int GetRank(void);
+/*! \brief get total number of process */
+int GetWorldSize(void);
 /*!
  * \brief this is used to check if sync module is a true distributed implementation, or simply a dummpy
  */

View File

@@ -17,6 +17,10 @@ bool IsDistributed(void) {
   return false;
 }
+int GetWorldSize(void) {
+  return 1;
+}
 template<>
 void AllReduce<uint32_t>(uint32_t *sendrecvbuf, int count, ReduceOp op) {
 }

View File

@@ -8,6 +8,10 @@ int GetRank(void) {
   return MPI::COMM_WORLD.Get_rank();
 }
+int GetWorldSize(void) {
+  return MPI::COMM_WORLD.Get_size();
+}
 void Init(int argc, char *argv[]) {
   MPI::Init(argc, argv);
 }

View File

@@ -160,7 +160,7 @@ class BoostLearnTask {
       if (!silent) printf("boosting round %d, %lu sec elapsed\n", i, elapsed);
       learner.UpdateOneIter(i, *data);
       std::string res = learner.EvalOneIter(i, devalall, eval_data_names);
-      if (silent < 1) {
+      if (silent < 2) {
         fprintf(stderr, "%s\n", res.c_str());
       }
       if (save_period != 0 && (i + 1) % save_period == 0) {