io part refactor
This commit is contained in:
@@ -33,6 +33,7 @@ class BoostLearner : public rabit::Serializable {
|
||||
silent= 0;
|
||||
prob_buffer_row = 1.0f;
|
||||
distributed_mode = 0;
|
||||
updater_mode = 0;
|
||||
pred_buffer_size = 0;
|
||||
seed_per_iteration = 0;
|
||||
seed = 0;
|
||||
@@ -95,6 +96,7 @@ class BoostLearner : public rabit::Serializable {
|
||||
utils::Error("%s is invalid value for dsplit, should be row or col", val);
|
||||
}
|
||||
}
|
||||
if (!strcmp(name, "updater_mode")) updater_mode = atoi(val);
|
||||
if (!strcmp(name, "prob_buffer_row")) {
|
||||
prob_buffer_row = static_cast<float>(atof(val));
|
||||
utils::Check(distributed_mode == 0,
|
||||
@@ -259,9 +261,17 @@ class BoostLearner : public rabit::Serializable {
|
||||
*/
|
||||
inline void CheckInit(DMatrix *p_train) {
|
||||
int ncol = static_cast<int>(p_train->info.info.num_col);
|
||||
std::vector<bool> enabled(ncol, true);
|
||||
std::vector<bool> enabled(ncol, true);
|
||||
// set max row per batch to limited value
|
||||
// in distributed mode, use safe choice otherwise
|
||||
size_t max_row_perbatch = std::numeric_limits<size_t>::max();
|
||||
if (updater_mode != 0 || distributed_mode == 2) {
|
||||
max_row_perbatch = 32UL << 10UL;
|
||||
}
|
||||
// initialize column access
|
||||
p_train->fmat()->InitColAccess(enabled, prob_buffer_row);
|
||||
p_train->fmat()->InitColAccess(enabled,
|
||||
prob_buffer_row,
|
||||
max_row_perbatch);
|
||||
const int kMagicPage = 0xffffab02;
|
||||
// check, if it is DMatrixPage, then use hist maker
|
||||
if (p_train->magic == kMagicPage) {
|
||||
@@ -480,6 +490,8 @@ class BoostLearner : public rabit::Serializable {
|
||||
int silent;
|
||||
// distributed learning mode, if any, 0:none, 1:col, 2:row
|
||||
int distributed_mode;
|
||||
// updater mode, 0:normal, reserved for internal test
|
||||
int updater_mode;
|
||||
// cached size of predict buffer
|
||||
size_t pred_buffer_size;
|
||||
// maximum buffred row value
|
||||
|
||||
Reference in New Issue
Block a user