io part refactor

This commit is contained in:
tqchen
2015-06-02 23:18:31 -07:00
parent e5dd894960
commit 2937f5eebc
7 changed files with 276 additions and 113 deletions

View File

@@ -33,6 +33,7 @@ class BoostLearner : public rabit::Serializable {
silent= 0;
prob_buffer_row = 1.0f;
distributed_mode = 0;
updater_mode = 0;
pred_buffer_size = 0;
seed_per_iteration = 0;
seed = 0;
@@ -95,6 +96,7 @@ class BoostLearner : public rabit::Serializable {
utils::Error("%s is invalid value for dsplit, should be row or col", val);
}
}
if (!strcmp(name, "updater_mode")) updater_mode = atoi(val);
if (!strcmp(name, "prob_buffer_row")) {
prob_buffer_row = static_cast<float>(atof(val));
utils::Check(distributed_mode == 0,
@@ -259,9 +261,17 @@ class BoostLearner : public rabit::Serializable {
*/
inline void CheckInit(DMatrix *p_train) {
int ncol = static_cast<int>(p_train->info.info.num_col);
std::vector<bool> enabled(ncol, true);
std::vector<bool> enabled(ncol, true);
// set max row per batch to limited value
// in distributed mode, use safe choice otherwise
size_t max_row_perbatch = std::numeric_limits<size_t>::max();
if (updater_mode != 0 || distributed_mode == 2) {
max_row_perbatch = 32UL << 10UL;
}
// initialize column access
p_train->fmat()->InitColAccess(enabled, prob_buffer_row);
p_train->fmat()->InitColAccess(enabled,
prob_buffer_row,
max_row_perbatch);
const int kMagicPage = 0xffffab02;
// check, if it is DMatrixPage, then use hist maker
if (p_train->magic == kMagicPage) {
@@ -480,6 +490,8 @@ class BoostLearner : public rabit::Serializable {
int silent;
// distributed learning mode, if any, 0:none, 1:col, 2:row
int distributed_mode;
// updater mode, 0:normal, reserved for internal test
int updater_mode;
// cached size of predict buffer
size_t pred_buffer_size;
// maximum buffred row value