diff --git a/doc/parameter.rst b/doc/parameter.rst
index d56ab130c..315df6e33 100644
--- a/doc/parameter.rst
+++ b/doc/parameter.rst
@@ -110,7 +110,7 @@ Parameters for Tree Booster
 * ``tree_method`` string [default= ``auto``]
 
   - The tree construction algorithm used in XGBoost. See description in the `reference paper <http://arxiv.org/abs/1603.02754>`_.
-  - Distributed and external memory version only support ``tree_method=approx``.
+  - XGBoost supports ``hist`` and ``approx`` for distributed training, and only supports ``approx`` for the external-memory version.
   - Choices: ``auto``, ``exact``, ``approx``, ``hist``, ``gpu_exact``, ``gpu_hist``
 
     - ``auto``: Use heuristic to choose the fastest method.
@@ -152,7 +152,7 @@ Parameters for Tree Booster
   - ``refresh``: refreshes tree's statistics and/or leaf values based on the current data. Note that no random subsampling of data rows is performed.
   - ``prune``: prunes the splits where loss < min_split_loss (or gamma).
 
-  - In a distributed setting, the implicit updater sequence value would be adjusted to ``grow_histmaker,prune``.
+  - In a distributed setting, the implicit updater sequence value is adjusted to ``grow_histmaker,prune`` by default; you can set ``tree_method`` to ``hist`` to use ``grow_histmaker``.
 
 * ``refresh_leaf`` [default=1]
diff --git a/include/xgboost/build_config.h b/include/xgboost/build_config.h
index 1e36dc808..b3a8bcdc3 100644
--- a/include/xgboost/build_config.h
+++ b/include/xgboost/build_config.h
@@ -1,20 +1,18 @@
 /*!
- * Copyright (c) 2018 by Contributors
+ * Copyright 2019 by Contributors
  * \file build_config.h
- * \brief Fall-back logic for platform-specific feature detection.
- * \author Hyunsu Philip Cho
  */
 #ifndef XGBOOST_BUILD_CONFIG_H_
 #define XGBOOST_BUILD_CONFIG_H_
 
 /* default logic for software pre-fetching */
 #if (defined(_MSC_VER) && (defined(_M_IX86) || defined(_M_AMD64))) || defined(__INTEL_COMPILER)
-  // Enable _mm_prefetch for Intel compiler and MSVC+x86
+// Enable _mm_prefetch for Intel compiler and MSVC+x86
   #define XGBOOST_MM_PREFETCH_PRESENT
   #define XGBOOST_BUILTIN_PREFETCH_PRESENT
 #elif defined(__GNUC__)
-  // Enable __builtin_prefetch for GCC
-  #define XGBOOST_BUILTIN_PREFETCH_PRESENT
+// Enable __builtin_prefetch for GCC
+#define XGBOOST_BUILTIN_PREFETCH_PRESENT
 #endif
 
 #endif  // XGBOOST_BUILD_CONFIG_H_
diff --git a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala
index 4cef75865..d9375361b 100644
--- a/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala
+++ b/jvm-packages/xgboost4j-example/src/main/scala/ml/dmlc/xgboost4j/scala/example/spark/SparkTraining.scala
@@ -31,7 +31,6 @@ object SparkTraining {
       println("Usage: program input_path")
       sys.exit(1)
     }
-    val spark = SparkSession.builder().getOrCreate()
 
     val inputPath = args(0)
     val schema = new StructType(Array(
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
index 7914fb817..2fe554064 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala
@@ -263,8 +263,10 @@ object XGBoost extends Serializable {
     validateSparkSslConf(sparkContext)
 
     if (params.contains("tree_method")) {
-      require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
-        " for now")
+      require(params("tree_method") == "hist" ||
+        params("tree_method") == "approx" ||
+        params("tree_method") == "auto", "xgboost4j-spark only supports tree_method as" +
+        " 'hist', 'approx' or 'auto'")
     }
     if (params.contains("train_test_ratio")) {
       logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
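A quick illustration of what the relaxed check above now accepts — a minimal sketch, assuming a DataFrame with the usual `features`/`label` columns; the variable names here are illustrative and not part of this diff:

```scala
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

// Any of "hist", "approx", or "auto" now passes validation; other values
// (e.g. "exact") are rejected up front by the require() above.
val params: Map[String, Any] = Map(
  "objective" -> "binary:logistic",
  "tree_method" -> "hist",   // distributed fast-histogram training
  "num_round" -> 5,
  "num_workers" -> 2
)
val classifier = new XGBoostClassifier(params)
  .setFeaturesCol("features")
  .setLabelCol("label")
// val model = classifier.fit(trainingDF)  // trainingDF: prepared by the caller
```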
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala
index 04b2d700c..870208982 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/BoosterParams.scala
@@ -50,10 +50,21 @@ private[spark] trait BoosterParams extends Params {
    * overfitting. [default=6] range: [1, Int.MaxValue]
    */
   final val maxDepth = new IntParam(this, "maxDepth", "maximum depth of a tree, increase this " +
-    "value will make model more complex/likely to be overfitting.", (value: Int) => value >= 1)
+    "value will make model more complex/likely to be overfitting.", (value: Int) => value >= 0)
 
   final def getMaxDepth: Int = $(maxDepth)
 
+  /**
+   * Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set.
+   */
+  final val maxLeaves = new IntParam(this, "maxLeaves",
+    "Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set.",
+    (value: Int) => value >= 0)
+
+  final def getMaxLeaves: Int = $(maxLeaves)
+
   /**
    * minimum sum of instance weight(hessian) needed in a child. If the tree partition step results
    * in a leaf node with the sum of instance weight less than min_child_weight, then the building
@@ -147,7 +158,9 @@ private[spark] trait BoosterParams extends Params {
    * growth policy for fast histogram algorithm
    */
   final val growPolicy = new Param[String](this, "growPolicy",
-    "growth policy for fast histogram algorithm",
+    "Controls the way new nodes are added to the tree. Currently supported only if" +
+      " tree_method is set to hist. Choices: depthwise, lossguide. depthwise: split at nodes" +
+      " closest to the root. lossguide: split at nodes with highest loss change.",
     (value: String) => BoosterParams.supportedGrowthPolicies.contains(value))
 
   final def getGrowPolicy: String = $(growPolicy)
@@ -242,6 +255,22 @@ private[spark] trait BoosterParams extends Params {
 
   final def getTreeLimit: Int = $(treeLimit)
 
+  final val monotoneConstraints = new Param[String](this, name = "monotoneConstraints",
+    doc = "a list whose length equals the number of features; 1 indicates monotonically " +
+      "increasing, -1 means decreasing, 0 means no constraint. If it is shorter than the " +
+      "number of features, it will be padded with 0s")
+
+  final def getMonotoneConstraints: String = $(monotoneConstraints)
+
+  final val interactionConstraints = new Param[String](this,
+    name = "interactionConstraints",
+    doc = "Constraints for interaction representing permitted interactions. The constraints" +
+      " must be specified in the form of a nested list, e.g. [[0, 1], [2, 3, 4]]," +
+      " where each inner list is a group of indices of features that are allowed to interact" +
+      " with each other. See the tutorial for more information")
+
+  final def getInteractionConstraints: String = $(interactionConstraints)
+
   setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6,
     minChildWeight -> 1, maxDeltaStep -> 0,
     growPolicy -> "depthwise", maxBins -> 16,
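For context, this is how the two new parameters would be passed through the xgboost4j-spark param map — a sketch only; the `(1, 0)` and `[[1,2],[2,3,4]]` string encodings follow the new test cases later in this diff:

```scala
// Illustrative parameter map exercising the new constraint params.
val constrainedParams: Map[String, Any] = Map(
  "objective" -> "binary:logistic",
  "tree_method" -> "hist",
  // one entry per feature: 1 = increasing, -1 = decreasing, 0 = unconstrained;
  // lists shorter than the feature count are padded with 0
  "monotone_constraints" -> "(1, 0)",
  // nested list of feature indices allowed to interact with each other
  "interaction_constraints" -> "[[1,2],[2,3,4]]",
  "num_round" -> 5,
  "num_workers" -> 2
)
```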
See tutorial for more information") + + final def getInteractionConstraints: String = $(interactionConstraints) + setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6, minChildWeight -> 1, maxDeltaStep -> 0, growPolicy -> "depthwise", maxBins -> 16, diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala index cd8b0b361..6d9c9b474 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/GeneralParams.scala @@ -231,10 +231,11 @@ private[spark] trait ParamMapFuncs extends Params { def XGBoostToMLlibParams(xgboostParams: Map[String, Any]): Unit = { for ((paramName, paramValue) <- xgboostParams) { if ((paramName == "booster" && paramValue != "gbtree") || - (paramName == "updater" && paramValue != "grow_histmaker,prune")) { + (paramName == "updater" && (paramValue != "grow_histmaker,prune" || + paramValue != "hist"))) { throw new IllegalArgumentException(s"you specified $paramName as $paramValue," + s" XGBoost-Spark only supports gbtree as booster type" + - " and grow_histmaker,prune as the updater type") + " and grow_histmaker,prune or hist as the updater type") } val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName) params.find(_.name == name) match { diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala index a6f01ebd2..865c03c4a 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala @@ -18,18 +18,21 @@ package ml.dmlc.xgboost4j.scala.spark import java.nio.file.Files import java.util.concurrent.LinkedBlockingDeque -import ml.dmlc.xgboost4j.java.Rabit + import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} import ml.dmlc.xgboost4j.scala.DMatrix import ml.dmlc.xgboost4j.scala.rabit.RabitTracker import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} import org.apache.hadoop.fs.{FileSystem, Path} + import org.apache.spark.TaskContext import org.apache.spark.ml.linalg.Vectors import org.apache.spark.sql._ import org.scalatest.FunSuite import scala.util.Random +import ml.dmlc.xgboost4j.java.Rabit + class XGBoostGeneralSuite extends FunSuite with PerTest { test("test Rabit allreduce to validate Scala-implemented Rabit tracker") { @@ -108,66 +111,89 @@ class XGBoostGeneralSuite extends FunSuite with PerTest { assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1) } - - ignore("test with fast histo depthwise") { + test("test with fast histo with monotone_constraints") { val eval = new EvalError() val training = buildDataFrame(Classification.train) val testDM = new DMatrix(Classification.test.iterator) - val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6", "silent" -> "1", + val paramMap = Map("eta" -> "1", + "max_depth" -> "6", "silent" -> "1", "objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise", - "eval_metric" -> "error", "num_round" -> 5, "num_workers" -> math.min(numWorkers, 2)) - // TODO: histogram algorithm seems to be very very sensitive to worker number + "num_round" 
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index a6f01ebd2..865c03c4a 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -18,18 +18,21 @@ package ml.dmlc.xgboost4j.scala.spark
 import java.nio.file.Files
 import java.util.concurrent.LinkedBlockingDeque
 
-import ml.dmlc.xgboost4j.java.Rabit
+
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import ml.dmlc.xgboost4j.scala.DMatrix
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import org.apache.hadoop.fs.{FileSystem, Path}
+
 import org.apache.spark.TaskContext
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.sql._
 import org.scalatest.FunSuite
 
 import scala.util.Random
 
+import ml.dmlc.xgboost4j.java.Rabit
+
 class XGBoostGeneralSuite extends FunSuite with PerTest {
   test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
@@ -108,66 +111,89 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
   }
 
-
-  ignore("test with fast histo depthwise") {
+  test("test with fast histo with monotone_constraints") {
     val eval = new EvalError()
     val training = buildDataFrame(Classification.train)
     val testDM = new DMatrix(Classification.test.iterator)
-    val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6", "silent" -> "1",
+    val paramMap = Map("eta" -> "1",
+      "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise",
-      "eval_metric" -> "error", "num_round" -> 5, "num_workers" -> math.min(numWorkers, 2))
-    // TODO: histogram algorithm seems to be very very sensitive to worker number
+      "num_round" -> 5, "num_workers" -> numWorkers, "monotone_constraints" -> "(1, 0)")
     val model = new XGBoostClassifier(paramMap).fit(training)
     assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
   }
 
-  ignore("test with fast histo lossguide") {
+  test("test with fast histo with interaction_constraints") {
+    val eval = new EvalError()
+    val training = buildDataFrame(Classification.train)
+    val testDM = new DMatrix(Classification.test.iterator)
+    val paramMap = Map("eta" -> "1",
+      "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise",
+      "num_round" -> 5, "num_workers" -> numWorkers,
+      "interaction_constraints" -> "[[1,2],[2,3,4]]")
+    val model = new XGBoostClassifier(paramMap).fit(training)
+    assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
+  }
+
+  test("test with fast histo depthwise") {
+    val eval = new EvalError()
+    val training = buildDataFrame(Classification.train)
+    val testDM = new DMatrix(Classification.test.iterator)
+    val paramMap = Map("eta" -> "1",
+      "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise",
+      "num_round" -> 5, "num_workers" -> numWorkers)
+    val model = new XGBoostClassifier(paramMap).fit(training)
+    assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
+  }
+
+  test("test with fast histo lossguide") {
     val eval = new EvalError()
     val training = buildDataFrame(Classification.train)
     val testDM = new DMatrix(Classification.test.iterator)
     val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "1",
       "objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "lossguide",
-      "max_leaves" -> "8", "eval_metric" -> "error", "num_round" -> 5,
-      "num_workers" -> math.min(numWorkers, 2))
+      "max_leaves" -> "8", "num_round" -> 5,
+      "num_workers" -> numWorkers)
     val model = new XGBoostClassifier(paramMap).fit(training)
     val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
     assert(x < 0.1)
   }
 
-  ignore("test with fast histo lossguide with max bin") {
+  test("test with fast histo lossguide with max bin") {
    val eval = new EvalError()
     val training = buildDataFrame(Classification.train)
     val testDM = new DMatrix(Classification.test.iterator)
     val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
       "objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "lossguide",
       "max_leaves" -> "8", "max_bin" -> "16",
-      "eval_metric" -> "error", "num_round" -> 5, "num_workers" -> math.min(numWorkers, 2))
+      "eval_metric" -> "error", "num_round" -> 5, "num_workers" -> numWorkers)
     val model = new XGBoostClassifier(paramMap).fit(training)
     val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
     assert(x < 0.1)
   }
 
-  ignore("test with fast histo depthwidth with max depth") {
+  test("test with fast histo depthwise with max depth") {
     val eval = new EvalError()
     val training = buildDataFrame(Classification.train)
     val testDM = new DMatrix(Classification.test.iterator)
-    val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
+    val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6", "silent" -> "0",
       "objective" -> "binary:logistic", "tree_method" -> "hist",
-      "grow_policy" -> "depthwise", "max_leaves" -> "8", "max_depth" -> "2",
-      "eval_metric" -> "error", "num_round" -> 10, "num_workers" -> math.min(numWorkers, 2))
+      "grow_policy" -> "depthwise", "max_depth" -> "2",
+      "eval_metric" -> "error", "num_round" -> 10, "num_workers" -> numWorkers)
     val model = new XGBoostClassifier(paramMap).fit(training)
     val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
     assert(x < 0.1)
   }
 
-  ignore("test with fast histo depthwidth with max depth and max bin") {
+  test("test with fast histo depthwise with max depth and max bin") {
     val eval = new EvalError()
     val training = buildDataFrame(Classification.train)
     val testDM = new DMatrix(Classification.test.iterator)
-    val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
+    val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6", "silent" -> "0",
       "objective" -> "binary:logistic", "tree_method" -> "hist",
       "grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
-      "eval_metric" -> "error", "num_round" -> 10, "num_workers" -> math.min(numWorkers, 2))
+      "eval_metric" -> "error", "num_round" -> 10, "num_workers" -> numWorkers)
     val model = new XGBoostClassifier(paramMap).fit(training)
     val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
     assert(x < 0.1)
diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java
index 5ef2db049..41611edf6 100644
--- a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java
+++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/BoosterImplTest.java
@@ -382,11 +382,12 @@ public class BoosterImplTest {
         metrics, null, null, 0);
     for (int i = 0; i < metrics.length; i++)
       for (int j = 1; j < metrics[i].length; j++) {
-        TestCase.assertTrue(metrics[i][j] >= metrics[i][j - 1]);
+        TestCase.assertTrue(metrics[i][j] >= metrics[i][j - 1] ||
+            Math.abs(metrics[i][j] - metrics[i][j - 1]) < 0.1);
       }
     for (int i = 0; i < metrics.length; i++)
       for (int j = 0; j < metrics[i].length; j++) {
-        TestCase.assertTrue(metrics[i][j] >= threshold);
+        TestCase.assertTrue(metrics[i][j] >= threshold);
       }
     booster.dispose();
   }
diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc
index a988d3baf..1cca8179b 100644
--- a/src/common/hist_util.cc
+++ b/src/common/hist_util.cc
@@ -83,9 +83,9 @@ void HistCutMatrix::Init
     summary_array[i].Reserve(max_num_bins * kFactor);
     summary_array[i].SetPrune(out, max_num_bins * kFactor);
   }
+  CHECK_EQ(summary_array.size(), in_sketchs->size());
   size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor);
   sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());
-  this->min_val.resize(sketchs.size());
 
   row_ptr.push_back(0);
   for (size_t fid = 0; fid < summary_array.size(); ++fid) {
@@ -479,14 +479,14 @@ void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,
 
 #pragma omp parallel for num_threads(std::min(nthread, n_blocks)) schedule(guided)
   for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
-    const size_t istart = iblock*block_size;
-    const size_t iend = (((iblock+1)*block_size > size) ? size : istart + block_size);
+    const size_t istart = iblock * block_size;
+    const size_t iend = (((iblock + 1) * block_size > size) ? size : istart + block_size);
 
-    const size_t bin = 2*thread_init_[0]*nbins_;
-    memcpy(hist_data + istart, (data + bin + istart), sizeof(double)*(iend - istart));
+    const size_t bin = 2 * thread_init_[0] * nbins_;
+    memcpy(hist_data + istart, (data + bin + istart), sizeof(double) * (iend - istart));
 
     for (size_t i_bin_part = 1; i_bin_part < n_worked_bins; ++i_bin_part) {
-      const size_t bin = 2*thread_init_[i_bin_part]*nbins_;
+      const size_t bin = 2 * thread_init_[i_bin_part] * nbins_;
       for (size_t i = istart; i < iend; i++) {
         hist_data[i] += data[bin + i];
       }
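The loop above merges per-thread partial histograms (2 * nbins doubles each, interleaved grad/hess sums) block by block so that different threads can merge disjoint index ranges in parallel. A conceptual Scala sketch of the same element-wise merge, serial and with illustrative names:

```scala
// Merge per-thread partial histograms by element-wise addition.
def mergeThreadHists(threadHists: Seq[Array[Double]]): Array[Double] = {
  require(threadHists.nonEmpty, "need at least one partial histogram")
  val merged = threadHists.head.clone()
  for (hist <- threadHists.tail; i <- merged.indices) merged(i) += hist(i)
  merged
}
```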
diff --git a/src/common/hist_util.h b/src/common/hist_util.h
index 2e916a726..ff5542768 100644
--- a/src/common/hist_util.h
+++ b/src/common/hist_util.h
@@ -13,6 +13,7 @@
 #include "row_set.h"
 #include "../tree/param.h"
 #include "./quantile.h"
+#include "../include/rabit/rabit.h"
 
 namespace xgboost {
 
@@ -43,6 +44,10 @@ struct GHistEntry {
     sum_hess += e.sum_hess;
   }
 
+  inline static void Reduce(GHistEntry& a, const GHistEntry& b) {  // NOLINT(*)
+    a.Add(b);
+  }
+
   /*! \brief set sum to be difference of two GHistEntry's */
   inline void SetSubtract(const GHistEntry& a, const GHistEntry& b) {
     sum_grad = a.sum_grad - b.sum_grad;
@@ -166,7 +171,7 @@ class GHistIndexBlockMatrix {
 };
 
 /*!
- * \brief histogram of graident statistics for a single node.
+ * \brief histogram of gradient statistics for a single node.
  * Consists of multiple GHistEntry's, each entry showing total gradient statistics
  * for that particular bin
  * Uses global bin id so as to represent all features simultaneously
@@ -254,6 +259,10 @@ class GHistBuilder {
   // construct a histogram via subtraction trick
   void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent);
 
+  uint32_t GetNumBins() {
+    return nbins_;
+  }
+
  private:
   /*! \brief number of threads for parallel computation */
   size_t nthread_;
diff --git a/src/common/random.h b/src/common/random.h
index e972b0d4a..ecce04765 100644
--- a/src/common/random.h
+++ b/src/common/random.h
@@ -17,6 +17,8 @@
 #include
 #include
 
+#include "io.h"
+
 namespace xgboost {
 namespace common {
 /*!
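`GHistEntry::Reduce` above gives rabit's allreduce an element-wise combiner: each bin's (sum_grad, sum_hess) pair is summed across workers. A conceptual sketch of that combine step in Scala — the real reduction runs inside rabit's C++ allreduce, so the names here are illustrative:

```scala
final case class GHistEntry(sumGrad: Double, sumHess: Double) {
  def add(other: GHistEntry): GHistEntry =
    GHistEntry(sumGrad + other.sumGrad, sumHess + other.sumHess)
}

// Element-wise combine of two workers' node histograms, bin by bin,
// as the allreduce applies it.
def reduceHists(a: Array[GHistEntry], b: Array[GHistEntry]): Array[GHistEntry] = {
  require(a.length == b.length, "histograms must cover the same bins")
  a.zip(b).map { case (x, y) => x.add(y) }
}
```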
diff --git a/src/learner.cc b/src/learner.cc
index 16f0b73a5..0f015479a 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -598,8 +598,8 @@ class LearnerImpl : public Learner {
     }
     const TreeMethod current_tree_method = tparam_.tree_method;
+
     if (rabit::IsDistributed()) {
-      /* Choose tree_method='approx' when distributed training is activated */
       CHECK(tparam_.dsplit != DataSplitMode::kAuto)
         << "Precondition violated; dsplit cannot be 'auto' in distributed mode";
       if (tparam_.dsplit == DataSplitMode::kCol) {
@@ -614,14 +614,13 @@ class LearnerImpl : public Learner {
                "for distributed training.";
           break;
         case TreeMethod::kApprox:
+        case TreeMethod::kHist:
           // things are okay, do nothing
           break;
         case TreeMethod::kExact:
-        case TreeMethod::kHist:
-          LOG(WARNING) << "Tree method was set to be '"
-                       << (current_tree_method == TreeMethod::kExact ?
-                           "exact" : "hist")
-                       << "', but only 'approx' is available for distributed "
+          LOG(CONSOLE) << "Tree method was set to be 'exact', "
+                          "but only 'approx' and 'hist' are available for distributed "
                           "training. The `tree_method` parameter is now being "
                           "changed to 'approx'";
           break;
@@ -633,7 +632,15 @@ class LearnerImpl : public Learner {
           LOG(FATAL) << "Unknown tree_method ("
                      << static_cast<int>(current_tree_method) << ") detected";
       }
-      tparam_.tree_method = TreeMethod::kApprox;
+      if (current_tree_method != TreeMethod::kHist) {
+        LOG(CONSOLE) << "Tree method is automatically selected to be 'approx'"
+                        " for distributed training.";
+        tparam_.tree_method = TreeMethod::kApprox;
+      } else {
+        LOG(CONSOLE) << "Tree method is specified to be 'hist'"
+                        " for distributed training.";
+        tparam_.tree_method = TreeMethod::kHist;
+      }
     } else if (!p_train->SingleColBlock()) {
       /* Some tree methods are not available for external-memory DMatrix */
       switch (current_tree_method) {
diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc
index d0fdd8ac3..706034dcc 100644
--- a/src/tree/updater_histmaker.cc
+++ b/src/tree/updater_histmaker.cc
@@ -126,6 +126,7 @@ class HistMaker: public BaseMaker {
   virtual void Update(const std::vector<GradientPair> &gpair,
                       DMatrix *p_fmat,
                       RegTree *p_tree) {
+    CHECK_GT(param_.max_depth, 0) << "max_depth must be larger than 0";
     this->InitData(gpair, *p_fmat, *p_tree);
     this->InitWorkSet(p_fmat, *p_tree, &fwork_set_);
     // mark root node as fresh.
@@ -345,10 +346,7 @@ class CQHistMaker: public HistMaker {
     this->wspace_.Init(this->param_, 1);
     // if it is C++11, use lazy evaluation for Allreduce,
     // to gain speedup in recovery
-#if __cplusplus >= 201103L
-    auto lazy_get_hist = [&]()
-#endif
-    {
+    auto lazy_get_hist = [&]() {
       thread_hist_.resize(omp_get_max_threads());
       // start accumulating statistics
       for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
@@ -371,22 +369,18 @@ class CQHistMaker: public HistMaker {
       for (size_t i = 0; i < this->qexpand_.size(); ++i) {
         const int nid = this->qexpand_[i];
         const int wid = this->node2workindex_[nid];
-        this->wspace_.hset[0][fset.size() + wid * (fset.size()+1)]
-            .data[0] = node_stats_[nid];
+        this->wspace_.hset[0][fset.size() + wid * (fset.size() + 1)]
+          .data[0] = node_stats_[nid];
       }
     };
     // sync the histogram
     // if it is C++11, use lazy evaluation for Allreduce
-#if __cplusplus >= 201103L
     this->histred_.Allreduce(dmlc::BeginPtr(this->wspace_.hset[0].data),
-                            this->wspace_.hset[0].data.size(), lazy_get_hist);
-#else
-    this->histred_.Allreduce(dmlc::BeginPtr(this->wspace_.hset[0].data),
-                            this->wspace_.hset[0].data.size());
-#endif
+                             this->wspace_.hset[0].data.size(), lazy_get_hist);
   }
+
   void ResetPositionAfterSplit(DMatrix *p_fmat,
-                              const RegTree &tree) override {
+                               const RegTree &tree) override {
     this->GetSplitSet(this->qexpand_, tree, &fsplit_set_);
   }
   void ResetPosAndPropose(const std::vector<GradientPair> &gpair,
diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc
index 62eb57de0..1d205c364 100644
--- a/src/tree/updater_quantile_hist.cc
+++ b/src/tree/updater_quantile_hist.cc
@@ -156,12 +156,18 @@ void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat,
       const int cright = (*p_tree)[nid].RightChild();
       hist_.AddHistRow(cleft);
       hist_.AddHistRow(cright);
-      if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) {
+      if (rabit::IsDistributed()) {
+        // in distributed mode, always build the left child so that all
+        // workers stay consistent and issue identical allreduce calls
         BuildHist(gpair_h, row_set_collection_[cleft], gmat, gmatb, hist_[cleft]);
         SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]);
       } else {
-        BuildHist(gpair_h, row_set_collection_[cright], gmat, gmatb, hist_[cright]);
-        SubtractionTrick(hist_[cleft], hist_[cright], hist_[nid]);
+        if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) {
+          BuildHist(gpair_h, row_set_collection_[cleft], gmat, gmatb, hist_[cleft]);
+          SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]);
+        } else {
+          BuildHist(gpair_h, row_set_collection_[cright], gmat, gmatb, hist_[cright]);
+          SubtractionTrick(hist_[cleft], hist_[cright], hist_[nid]);
+        }
       }
 
       time_build_hist += dmlc::GetTime() - tstart;
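The new branch pins the built child to the left one under rabit, so every worker issues the same sequence of allreduce calls; the sibling's histogram still comes from the subtraction trick. A sketch of that trick on (grad, hess) pairs, with illustrative types:

```scala
// Subtraction trick: parent histogram minus the built child's histogram
// yields the sibling's histogram without a second pass over the rows.
def subtractionTrick(parent: Array[(Double, Double)],
                     builtChild: Array[(Double, Double)]): Array[(Double, Double)] =
  parent.zip(builtChild).map { case ((pg, ph), (cg, ch)) => (pg - cg, ph - ch) }
```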
@@ -617,23 +623,34 @@ void QuantileHistMaker::Builder::InitNewNode(int nid,
   {
     auto& stats = snode_[nid].stats;
-    if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
-      /* specialized code for dense data
-         For dense data (with no missing value),
-         the sum of gradient histogram is equal to snode[nid] */
-      GHistRow hist = hist_[nid];
-      const std::vector<uint32_t>& row_ptr = gmat.cut.row_ptr;
-
-      const uint32_t ibegin = row_ptr[fid_least_bins_];
-      const uint32_t iend = row_ptr[fid_least_bins_ + 1];
-      for (uint32_t i = ibegin; i < iend; ++i) {
-        const GHistEntry et = hist.begin[i];
-        stats.Add(et.sum_grad, et.sum_hess);
+    GHistRow hist = hist_[nid];
+    if (rabit::IsDistributed()) {
+      // in distributed mode, the node's stats should be calculated from the
+      // histogram, otherwise we will get wrong results in EnumerateSplit();
+      // here we take the first feature in cut
+      for (size_t i = gmat.cut.row_ptr[0]; i < gmat.cut.row_ptr[1]; i++) {
+        stats.Add(hist.begin[i].sum_grad, hist.begin[i].sum_hess);
       }
     } else {
-      const RowSetCollection::Elem e = row_set_collection_[nid];
-      for (const size_t* it = e.begin; it < e.end; ++it) {
-        stats.Add(gpair[*it]);
+      if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
+        /* specialized code for dense data:
+           for dense data (with no missing values),
+           the sum of the gradient histogram is equal to snode[nid] */
+        const std::vector<uint32_t>& row_ptr = gmat.cut.row_ptr;
+
+        const uint32_t ibegin = row_ptr[fid_least_bins_];
+        const uint32_t iend = row_ptr[fid_least_bins_ + 1];
+        for (uint32_t i = ibegin; i < iend; ++i) {
+          const GHistEntry et = hist.begin[i];
+          stats.Add(et.sum_grad, et.sum_hess);
+        }
+      } else {
+        const RowSetCollection::Elem e = row_set_collection_[nid];
+        for (const size_t* it = e.begin; it < e.end; ++it) {
+          stats.Add(gpair[*it]);
+        }
       }
     }
   }
diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h
index 75d551070..4e8a1f276 100644
--- a/src/tree/updater_quantile_hist.h
+++ b/src/tree/updater_quantile_hist.h
@@ -105,6 +105,7 @@ class QuantileHistMaker: public TreeUpdater {
       } else {
         hist_builder_.BuildHist(gpair, row_indices, gmat, hist);
       }
+      this->histred_.Allreduce(hist.begin, hist_builder_.GetNumBins());
     }
 
     inline void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
@@ -225,6 +226,8 @@ class QuantileHistMaker: public TreeUpdater {
     enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
     DataLayout data_layout_;
+
+    rabit::Reducer<GHistEntry, GHistEntry::Reduce> histred_;
   };
 
   std::unique_ptr<Builder> builder_;
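In distributed mode `InitNewNode` now derives the node's totals from the synchronized histogram instead of local rows, so all workers agree. Summing the bins of a single feature recovers the totals for the same reason the existing dense-data specialization works: assuming that feature has no missing values, its bins partition the node's rows. A sketch of that identity, with illustrative names:

```scala
// Sum the bins of one feature in the already-allreduced histogram to get
// the node's total (grad, hess). Each bin holds a (grad, hess) pair and
// rowPtr delimits each feature's bin range, CSR-style.
def nodeStatsFromHist(hist: Array[(Double, Double)],
                      rowPtr: Array[Int], fid: Int): (Double, Double) =
  hist.slice(rowPtr(fid), rowPtr(fid + 1))
    .foldLeft((0.0, 0.0)) { case ((g, h), (bg, bh)) => (g + bg, h + bh) }
```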
diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc
index a9ca17415..98926cb22 100644
--- a/src/tree/updater_refresh.cc
+++ b/src/tree/updater_refresh.cc
@@ -52,10 +52,7 @@ class TreeRefresher: public TreeUpdater {
     }
     // if it is C++11, use lazy evaluation for Allreduce,
     // to gain speedup in recovery
-#if __cplusplus >= 201103L
-    auto lazy_get_stats = [&]()
-#endif
-    {
+    auto lazy_get_stats = [&]() {
       const MetaInfo &info = p_fmat->Info();
       // start accumulating statistics
       for (const auto &batch : p_fmat->GetRowBatches()) {
@@ -86,11 +83,7 @@ class TreeRefresher: public TreeUpdater {
         }
       }
     };
-#if __cplusplus >= 201103L
     reducer_.Allreduce(dmlc::BeginPtr(stemp[0]), stemp[0].size(), lazy_get_stats);
-#else
-    reducer_.Allreduce(dmlc::BeginPtr(stemp[0]), stemp[0].size());
-#endif
     // rescale learning rate according to size of trees
     float lr = param_.learning_rate;
     param_.learning_rate = lr / trees.size();
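Both call sites above now hand the statistics-gathering closure straight to `Allreduce`, so rabit can skip recomputing the buffer when a restored checkpoint already carries the reduced result. A minimal sketch of that contract — `allreduceLazy` is an illustrative stand-in, not rabit's actual API:

```scala
// Run `prepare` to fill `buf` only when no reduced result is already
// available from recovery; that skipped work is the speedup the removed
// C++11 guards were gating.
def allreduceLazy(buf: Array[Double], recovered: Boolean)(prepare: () => Unit): Unit = {
  if (!recovered) prepare()
  // ... element-wise sum of `buf` across workers would happen here ...
}
```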