add part_load col

This commit is contained in:
tqchen
2014-10-16 19:41:43 -07:00
parent f512f08437
commit 3f3c90c3c0
8 changed files with 66 additions and 12 deletions

View File

@@ -31,6 +31,7 @@ class BoostLearner {
name_gbm_ = "gbtree";
silent= 0;
prob_buffer_row = 1.0f;
part_load_col = 0;
}
~BoostLearner(void) {
if (obj_ != NULL) delete obj_;
@@ -88,6 +89,7 @@ class BoostLearner {
this->SetParam(n.c_str(), val);
}
if (!strcmp(name, "silent")) silent = atoi(val);
if (!strcmp(name, "part_load_col")) part_load_col = atoi(val);
if (!strcmp(name, "prob_buffer_row")) {
prob_buffer_row = static_cast<float>(atof(val));
this->SetParam("updater", "grow_colmaker,refresh,prune");
@@ -164,8 +166,41 @@ class BoostLearner {
* if not intialize it
* \param p_train pointer to the matrix used by training
*/
inline void CheckInit(DMatrix *p_train) {
p_train->fmat()->InitColAccess(prob_buffer_row);
inline void CheckInit(DMatrix *p_train) {
int ncol = p_train->info.info.num_col;
std::vector<bool> enabled(ncol, true);
if (part_load_col != 0) {
std::vector<unsigned> col_index;
for (int i = 0; i < ncol; ++i) {
col_index.push_back(i);
}
random::Shuffle(col_index);
std::string s_model;
utils::MemoryBufferStream ms(&s_model);
utils::IStream &fs = ms;
if (sync::GetRank() == 0) {
fs.Write(col_index);
sync::Bcast(&s_model, 0);
} else {
sync::Bcast(&s_model, 0);
fs.Read(&col_index);
}
int nsize = sync::GetWorldSize();
int step = (ncol + nsize -1) / nsize;
int pid = sync::GetRank();
std::fill(enabled.begin(), enabled.end(), false);
int start = step * pid;
int end = std::min(step * (pid + 1), ncol);
utils::Printf("rank %d idset:", pid);
for (int i = start; i < end; ++i) {
enabled[col_index[i]] = true;
utils::Printf(" %u", col_index[i]);
}
utils::Printf("\n");
}
// initialize column access
p_train->fmat()->InitColAccess(enabled, prob_buffer_row);
}
/*!
* \brief update the model for one iteration
@@ -316,6 +351,8 @@ class BoostLearner {
// data fields
// silent during training
int silent;
// randomly load part of data
int part_load_col;
// maximum buffred row value
float prob_buffer_row;
// evaluation set