[METHOD], add tree method option to prefer faster algo
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
* \brief Implementation of learning algorithm.
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/learner.h>
|
||||
#include <dmlc/io.h>
|
||||
#include <algorithm>
|
||||
@@ -69,6 +70,8 @@ struct LearnerTrainParam
|
||||
bool seed_per_iteration;
|
||||
// data split mode, can be row, col, or none.
|
||||
int dsplit;
|
||||
// tree construction method
|
||||
int tree_method;
|
||||
// internal test flag
|
||||
std::string test_flag;
|
||||
// maximum buffered row value
|
||||
@@ -87,6 +90,11 @@ struct LearnerTrainParam
|
||||
.add_enum("col", 1)
|
||||
.add_enum("row", 2)
|
||||
.describe("Data split mode for distributed trainig. ");
|
||||
DMLC_DECLARE_FIELD(tree_method).set_default(0)
|
||||
.add_enum("auto", 0)
|
||||
.add_enum("approx", 1)
|
||||
.add_enum("exact", 2)
|
||||
.describe("Choice of tree construction method.");
|
||||
DMLC_DECLARE_FIELD(test_flag).set_default("")
|
||||
.describe("Internal test flag");
|
||||
DMLC_DECLARE_FIELD(prob_buffer_row).set_default(1.0f).set_range(0.0f, 1.0f)
|
||||
@@ -349,21 +357,42 @@ class LearnerImpl : public Learner {
|
||||
// check if p_train is ready to used by training.
|
||||
// if not, initialize the column access.
|
||||
inline void LazyInitDMatrix(DMatrix *p_train) {
|
||||
if (p_train->HaveColAccess()) return;
|
||||
int ncol = static_cast<int>(p_train->info().num_col);
|
||||
std::vector<bool> enabled(ncol, true);
|
||||
// set max row per batch to limited value
|
||||
// in distributed mode, use safe choice otherwise
|
||||
size_t max_row_perbatch = tparam.max_row_perbatch;
|
||||
if (tparam.test_flag == "block" || tparam.dsplit == 2) {
|
||||
max_row_perbatch = std::min(
|
||||
static_cast<size_t>(32UL << 10UL), max_row_perbatch);
|
||||
if (!p_train->HaveColAccess()) {
|
||||
int ncol = static_cast<int>(p_train->info().num_col);
|
||||
std::vector<bool> enabled(ncol, true);
|
||||
// set max row per batch to limited value
|
||||
// in distributed mode, use safe choice otherwise
|
||||
size_t max_row_perbatch = tparam.max_row_perbatch;
|
||||
const size_t safe_max_row = static_cast<size_t>(32UL << 10UL);
|
||||
|
||||
if (tparam.tree_method == 0 &&
|
||||
p_train->info().num_row >= (4UL << 20UL)) {
|
||||
LOG(CONSOLE) << "Tree method is automatically selected to be \'approx\'"
|
||||
<< " for faster speed."
|
||||
<< " to use old behavior(exact greedy algorithm on single machine),"
|
||||
<< " set tree_method to \'exact\'";
|
||||
max_row_perbatch = std::min(max_row_perbatch, safe_max_row);
|
||||
}
|
||||
|
||||
if (tparam.tree_method == 1) {
|
||||
LOG(CONSOLE) << "Tree method is selected to be \'approx\'";
|
||||
max_row_perbatch = std::min(max_row_perbatch, safe_max_row);
|
||||
}
|
||||
|
||||
if (tparam.test_flag == "block" || tparam.dsplit == 2) {
|
||||
max_row_perbatch = std::min(max_row_perbatch, safe_max_row);
|
||||
}
|
||||
// initialize column access
|
||||
p_train->InitColAccess(enabled,
|
||||
tparam.prob_buffer_row,
|
||||
max_row_perbatch);
|
||||
}
|
||||
// initialize column access
|
||||
p_train->InitColAccess(enabled,
|
||||
tparam.prob_buffer_row,
|
||||
max_row_perbatch);
|
||||
|
||||
if (!p_train->SingleColBlock() && cfg_.count("updater") == 0) {
|
||||
if (tparam.tree_method == 2) {
|
||||
LOG(CONSOLE) << "tree method is set to be 'exact',"
|
||||
<< " but currently we are only able to proceed with approximate algorithm";
|
||||
}
|
||||
cfg_["updater"] = "grow_histmaker,prune";
|
||||
if (gbm_.get() != nullptr) {
|
||||
gbm_->Configure(cfg_.begin(), cfg_.end());
|
||||
|
||||
Reference in New Issue
Block a user