Distributed Fast Histogram Algorithm (#4011)
* add back train method but mark as deprecated
* add back train method but mark as deprecated
* add back train method but mark as deprecated
* fix scalastyle error
* fix scalastyle error
* fix scalastyle error
* fix scalastyle error
* init
* allow hist algo
* more changes
* temp
* update
* remove hist sync
* update rabit
* change hist size
* change the histogram
* update kfactor
* sync per node stats
* temp
* update
* final
* code clean
* update rabit
* more cleanup
* fix errors
* fix failed tests
* enforce c++11
* fix lint issue
* broadcast subsampled feature correctly
* revert some changes
* fix lint issue
* enable monotone and interaction constraints
* don't specify default for monotone and interactions
* update docs
parent 8905df4a18
commit ae3bb9c2d5
@@ -110,7 +110,7 @@ Parameters for Tree Booster
 * ``tree_method`` string [default= ``auto``]

   - The tree construction algorithm used in XGBoost. See description in the `reference paper <http://arxiv.org/abs/1603.02754>`_.
-  - Distributed and external memory version only support ``tree_method=approx``.
+  - XGBoost supports ``hist`` and ``approx`` for distributed training and only support ``approx`` for external memory version.
   - Choices: ``auto``, ``exact``, ``approx``, ``hist``, ``gpu_exact``, ``gpu_hist``

     - ``auto``: Use heuristic to choose the fastest method.
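For readers coming from xgboost4j-spark, a rough sketch of what requesting the fast histogram method now looks like end to end. The input path, column names and parameter values below are made-up placeholders; the parameter spellings follow the tests added later in this commit.

```scala
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession

object DistributedHistExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().getOrCreate()
    // Hypothetical input: a headered CSV with numeric columns f0..f3 and a 0/1 "label" column.
    val raw = spark.read.option("header", "true").option("inferSchema", "true").csv(args(0))
    val data = new VectorAssembler()
      .setInputCols(Array("f0", "f1", "f2", "f3"))
      .setOutputCol("features")
      .transform(raw)

    val params = Map(
      "objective" -> "binary:logistic",
      "tree_method" -> "hist",       // distributed fast histogram, enabled by this change
      "grow_policy" -> "depthwise",
      "max_depth" -> 6,
      "num_round" -> 50,
      "num_workers" -> 4)
    val model = new XGBoostClassifier(params)
      .setFeaturesCol("features")
      .setLabelCol("label")
      .fit(data)
    model.transform(data).show(5)
  }
}
```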
@@ -152,7 +152,7 @@ Parameters for Tree Booster
   - ``refresh``: refreshes tree's statistics and/or leaf values based on the current data. Note that no random subsampling of data rows is performed.
   - ``prune``: prunes the splits where loss < min_split_loss (or gamma).

-  - In a distributed setting, the implicit updater sequence value would be adjusted to ``grow_histmaker,prune``.
+  - In a distributed setting, the implicit updater sequence value would be adjusted to ``grow_histmaker,prune`` by default, and you can set ``tree_method`` as ``hist`` to use ``grow_histmaker``.

 * ``refresh_leaf`` [default=1]
@@ -1,20 +1,18 @@
 /*!
- * Copyright (c) 2018 by Contributors
+ * Copyright 2019 by Contributors
  * \file build_config.h
  * \brief Fall-back logic for platform-specific feature detection.
  * \author Hyunsu Philip Cho
  */
 #ifndef XGBOOST_BUILD_CONFIG_H_
 #define XGBOOST_BUILD_CONFIG_H_

 /* default logic for software pre-fetching */
 #if (defined(_MSC_VER) && (defined(_M_IX86) || defined(_M_AMD64))) || defined(__INTEL_COMPILER)
   // Enable _mm_prefetch for Intel compiler and MSVC+x86
   #define XGBOOST_MM_PREFETCH_PRESENT
   #define XGBOOST_BUILTIN_PREFETCH_PRESENT
 #elif defined(__GNUC__)
   // Enable __builtin_prefetch for GCC
   #define XGBOOST_BUILTIN_PREFETCH_PRESENT
 #endif

 #endif  // XGBOOST_BUILD_CONFIG_H_
@@ -31,7 +31,6 @@ object SparkTraining {
       println("Usage: program input_path")
       sys.exit(1)
     }

     val spark = SparkSession.builder().getOrCreate()
     val inputPath = args(0)
     val schema = new StructType(Array(
@@ -263,8 +263,10 @@ object XGBoost extends Serializable {
     validateSparkSslConf(sparkContext)

     if (params.contains("tree_method")) {
-      require(params("tree_method") != "hist", "xgboost4j-spark does not support fast histogram" +
-        " for now")
+      require(params("tree_method") == "hist" ||
+        params("tree_method") == "approx" ||
+        params("tree_method") == "auto", "xgboost4j-spark only supports tree_method as 'hist'," +
+        " 'approx' and 'auto'")
     }
     if (params.contains("train_test_ratio")) {
       logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
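In effect the check flips from a blacklist of ``hist`` to a whitelist. A small self-contained sketch of the accepted/rejected behaviour (the helper name is illustrative, not part of the codebase):

```scala
// Sketch of the validation introduced above: "hist", "approx" and "auto" pass,
// anything else fails fast before training starts.
val accepted = Set("hist", "approx", "auto")
def checkTreeMethod(params: Map[String, Any]): Unit =
  params.get("tree_method").foreach { tm =>
    require(accepted.contains(tm.toString),
      "xgboost4j-spark only supports tree_method as 'hist', 'approx' and 'auto'")
  }

checkTreeMethod(Map("tree_method" -> "hist"))    // ok
// checkTreeMethod(Map("tree_method" -> "exact")) // would throw IllegalArgumentException
```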
@@ -50,10 +50,21 @@ private[spark] trait BoosterParams extends Params {
    * overfitting. [default=6] range: [1, Int.MaxValue]
    */
   final val maxDepth = new IntParam(this, "maxDepth", "maximum depth of a tree, increase this " +
-    "value will make model more complex/likely to be overfitting.", (value: Int) => value >= 1)
+    "value will make model more complex/likely to be overfitting.", (value: Int) => value >= 0)

   final def getMaxDepth: Int = $(maxDepth)

+  /**
+   * Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set.
+   */
+  final val maxLeaves = new IntParam(this, "maxLeaves",
+    "Maximum number of nodes to be added. Only relevant when grow_policy=lossguide is set.",
+    (value: Int) => value >= 0)
+
+  final def getMaxLeaves: Int = $(maxDepth)
+
   /**
    * minimum sum of instance weight(hessian) needed in a child. If the tree partition step results
    * in a leaf node with the sum of instance weight less than min_child_weight, then the building
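Two things worth noting here: the relaxed validator (value >= 0) is what allows max_depth = 0 together with grow_policy = lossguide in the tests below, and getMaxLeaves as committed reads $(maxDepth) rather than $(maxLeaves), which looks like a copy-paste slip. A sketch of a loss-guided configuration (values are illustrative only; parameter spellings follow the test suite):

```scala
// Loss-guided growth: depth is left unbounded (max_depth = 0) and tree size
// is capped by max_leaves instead.
val lossguideParams = Map(
  "objective" -> "binary:logistic",
  "tree_method" -> "hist",
  "grow_policy" -> "lossguide",
  "max_depth" -> "0",   // allowed now that the validator accepts 0
  "max_leaves" -> "8",
  "num_round" -> 5,
  "num_workers" -> 2)
```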
@@ -147,7 +158,9 @@ private[spark] trait BoosterParams extends Params {
    * growth policy for fast histogram algorithm
    */
   final val growPolicy = new Param[String](this, "growPolicy",
-    "growth policy for fast histogram algorithm",
+    "Controls a way new nodes are added to the tree. Currently supported only if" +
+      " tree_method is set to hist. Choices: depthwise, lossguide. depthwise: split at nodes" +
+      " closest to the root. lossguide: split at nodes with highest loss change.",
     (value: String) => BoosterParams.supportedGrowthPolicies.contains(value))

   final def getGrowPolicy: String = $(growPolicy)
@@ -242,6 +255,22 @@ private[spark] trait BoosterParams extends Params {

   final def getTreeLimit: Int = $(treeLimit)

+  final val monotoneConstraints = new Param[String](this, name = "monotoneConstraints",
+    doc = "a list in length of number of features, 1 indicate monotonic increasing, - 1 means " +
+      "decreasing, 0 means no constraint. If it is shorter than number of features, 0 will be " +
+      "padded ")
+
+  final def getMonotoneConstraints: String = $(monotoneConstraints)
+
+  final val interactionConstraints = new Param[String](this,
+    name = "interactionConstraints",
+    doc = "Constraints for interaction representing permitted interactions. The constraints" +
+      " must be specified in the form of a nest list, e.g. [[0, 1], [2, 3, 4]]," +
+      " where each inner list is a group of indices of features that are allowed to interact" +
+      " with each other. See tutorial for more information")
+
+  final def getInteractionConstraints: String = $(interactionConstraints)
+
   setDefault(eta -> 0.3, gamma -> 0, maxDepth -> 6,
     minChildWeight -> 1, maxDeltaStep -> 0,
     growPolicy -> "depthwise", maxBins -> 16,
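Both new parameters are passed through as plain strings in XGBoost's own syntax. A sketch of how they look in a parameter map, mirroring the values used by the tests added in this commit:

```scala
// monotone_constraints: one entry per feature, 1 = increasing, -1 = decreasing, 0 = free.
// interaction_constraints: a nested list of feature indices allowed to interact.
val constrainedParams = Map(
  "objective" -> "binary:logistic",
  "tree_method" -> "hist",
  "monotone_constraints" -> "(1, 0)",
  "interaction_constraints" -> "[[1,2],[2,3,4]]",
  "num_round" -> 5,
  "num_workers" -> 2)
```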
@@ -231,10 +231,11 @@ private[spark] trait ParamMapFuncs extends Params {
   def XGBoostToMLlibParams(xgboostParams: Map[String, Any]): Unit = {
     for ((paramName, paramValue) <- xgboostParams) {
       if ((paramName == "booster" && paramValue != "gbtree") ||
-        (paramName == "updater" && paramValue != "grow_histmaker,prune")) {
+        (paramName == "updater" && (paramValue != "grow_histmaker,prune" ||
+          paramValue != "hist"))) {
         throw new IllegalArgumentException(s"you specified $paramName as $paramValue," +
           s" XGBoost-Spark only supports gbtree as booster type" +
-          " and grow_histmaker,prune as the updater type")
+          " and grow_histmaker,prune or hist as the updater type")
       }
       val name = CaseFormat.LOWER_UNDERSCORE.to(CaseFormat.LOWER_CAMEL, paramName)
       params.find(_.name == name) match {
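Note that `paramValue != "grow_histmaker,prune" || paramValue != "hist"` is true for every possible value (nothing can equal both strings at once), so as committed the updater branch always fires when "updater" is set. A hedged sketch of the check the error message appears to intend, which is an assumption and not what this commit contains:

```scala
// Hypothetical fix: the updater is invalid only when it matches none of the accepted values.
val acceptedUpdaters = Set("grow_histmaker,prune", "hist")
def updaterIsInvalid(value: Any): Boolean = !acceptedUpdaters.contains(value.toString)

updaterIsInvalid("hist")                 // false -> accepted
updaterIsInvalid("grow_colmaker,prune")  // true  -> rejected
```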
@@ -18,18 +18,21 @@ package ml.dmlc.xgboost4j.scala.spark

 import java.nio.file.Files
 import java.util.concurrent.LinkedBlockingDeque
+
+import ml.dmlc.xgboost4j.java.Rabit
 import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
 import ml.dmlc.xgboost4j.scala.DMatrix
 import ml.dmlc.xgboost4j.scala.rabit.RabitTracker
 import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _}
 import org.apache.hadoop.fs.{FileSystem, Path}

 import org.apache.spark.TaskContext
 import org.apache.spark.ml.linalg.Vectors
 import org.apache.spark.sql._
 import org.scalatest.FunSuite
 import scala.util.Random
-
-import ml.dmlc.xgboost4j.java.Rabit

 class XGBoostGeneralSuite extends FunSuite with PerTest {

   test("test Rabit allreduce to validate Scala-implemented Rabit tracker") {
@@ -108,66 +111,89 @@ class XGBoostGeneralSuite extends FunSuite with PerTest {
     assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
   }

-  ignore("test with fast histo depthwise") {
+  test("test with fast histo with monotone_constraints") {
     val eval = new EvalError()
     val training = buildDataFrame(Classification.train)
     val testDM = new DMatrix(Classification.test.iterator)
-    val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6", "silent" -> "1",
+    val paramMap = Map("eta" -> "1",
+      "max_depth" -> "6", "silent" -> "1",
       "objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise",
-      "eval_metric" -> "error", "num_round" -> 5, "num_workers" -> math.min(numWorkers, 2))
-    // TODO: histogram algorithm seems to be very very sensitive to worker number
+      "num_round" -> 5, "num_workers" -> numWorkers, "monotone_constraints" -> "(1, 0)")
     val model = new XGBoostClassifier(paramMap).fit(training)
     assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
   }

-  ignore("test with fast histo lossguide") {
+  test("test with fast histo with interaction_constraints") {
+    val eval = new EvalError()
+    val training = buildDataFrame(Classification.train)
+    val testDM = new DMatrix(Classification.test.iterator)
+    val paramMap = Map("eta" -> "1",
+      "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise",
+      "num_round" -> 5, "num_workers" -> numWorkers, "interaction_constraints" -> "[[1,2],[2,3,4]]")
+    val model = new XGBoostClassifier(paramMap).fit(training)
+    assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
+  }
+
+  test("test with fast histo depthwise") {
+    val eval = new EvalError()
+    val training = buildDataFrame(Classification.train)
+    val testDM = new DMatrix(Classification.test.iterator)
+    val paramMap = Map("eta" -> "1",
+      "max_depth" -> "6", "silent" -> "1",
+      "objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "depthwise",
+      "num_round" -> 5, "num_workers" -> numWorkers)
+    val model = new XGBoostClassifier(paramMap).fit(training)
+    assert(eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM) < 0.1)
+  }
+
+  test("test with fast histo lossguide") {
     val eval = new EvalError()
     val training = buildDataFrame(Classification.train)
     val testDM = new DMatrix(Classification.test.iterator)
     val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "1",
       "objective" -> "binary:logistic", "tree_method" -> "hist", "grow_policy" -> "lossguide",
-      "max_leaves" -> "8", "eval_metric" -> "error", "num_round" -> 5,
-      "num_workers" -> math.min(numWorkers, 2))
+      "max_leaves" -> "8", "num_round" -> 5,
+      "num_workers" -> numWorkers)
     val model = new XGBoostClassifier(paramMap).fit(training)
     val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
     assert(x < 0.1)
   }

-  ignore("test with fast histo lossguide with max bin") {
+  test("test with fast histo lossguide with max bin") {
     val eval = new EvalError()
     val training = buildDataFrame(Classification.train)
     val testDM = new DMatrix(Classification.test.iterator)
     val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
       "objective" -> "binary:logistic", "tree_method" -> "hist",
       "grow_policy" -> "lossguide", "max_leaves" -> "8", "max_bin" -> "16",
-      "eval_metric" -> "error", "num_round" -> 5, "num_workers" -> math.min(numWorkers, 2))
+      "eval_metric" -> "error", "num_round" -> 5, "num_workers" -> numWorkers)
     val model = new XGBoostClassifier(paramMap).fit(training)
     val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
     assert(x < 0.1)
   }

-  ignore("test with fast histo depthwidth with max depth") {
+  test("test with fast histo depthwidth with max depth") {
     val eval = new EvalError()
     val training = buildDataFrame(Classification.train)
     val testDM = new DMatrix(Classification.test.iterator)
-    val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
+    val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6", "silent" -> "0",
       "objective" -> "binary:logistic", "tree_method" -> "hist",
-      "grow_policy" -> "depthwise", "max_leaves" -> "8", "max_depth" -> "2",
-      "eval_metric" -> "error", "num_round" -> 10, "num_workers" -> math.min(numWorkers, 2))
+      "grow_policy" -> "depthwise", "max_depth" -> "2",
+      "eval_metric" -> "error", "num_round" -> 10, "num_workers" -> numWorkers)
     val model = new XGBoostClassifier(paramMap).fit(training)
     val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
     assert(x < 0.1)
   }

-  ignore("test with fast histo depthwidth with max depth and max bin") {
+  test("test with fast histo depthwidth with max depth and max bin") {
     val eval = new EvalError()
     val training = buildDataFrame(Classification.train)
     val testDM = new DMatrix(Classification.test.iterator)
-    val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "0", "silent" -> "0",
+    val paramMap = Map("eta" -> "1", "gamma" -> "0.5", "max_depth" -> "6", "silent" -> "0",
       "objective" -> "binary:logistic", "tree_method" -> "hist",
       "grow_policy" -> "depthwise", "max_depth" -> "2", "max_bin" -> "2",
-      "eval_metric" -> "error", "num_round" -> 10, "num_workers" -> math.min(numWorkers, 2))
+      "eval_metric" -> "error", "num_round" -> 10, "num_workers" -> numWorkers)
     val model = new XGBoostClassifier(paramMap).fit(training)
     val x = eval.eval(model._booster.predict(testDM, outPutMargin = true), testDM)
     assert(x < 0.1)
@@ -382,7 +382,8 @@ public class BoosterImplTest {
         metrics, null, null, 0);
     for (int i = 0; i < metrics.length; i++)
       for (int j = 1; j < metrics[i].length; j++) {
-        TestCase.assertTrue(metrics[i][j] >= metrics[i][j - 1]);
+        TestCase.assertTrue(metrics[i][j] >= metrics[i][j - 1] ||
+          Math.abs(metrics[i][j] - metrics[i][j - 1]) < 0.1);
       }
     for (int i = 0; i < metrics.length; i++)
       for (int j = 0; j < metrics[i].length; j++) {
@@ -83,9 +83,9 @@ void HistCutMatrix::Init
     summary_array[i].Reserve(max_num_bins * kFactor);
     summary_array[i].SetPrune(out, max_num_bins * kFactor);
   }
   CHECK_EQ(summary_array.size(), in_sketchs->size());
   size_t nbytes = WXQSketch::SummaryContainer::CalcMemCost(max_num_bins * kFactor);
   sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size());

   this->min_val.resize(sketchs.size());
   row_ptr.push_back(0);
   for (size_t fid = 0; fid < summary_array.size(); ++fid) {
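For context, the Allreduce here is what makes every worker end up with identical cut points: each worker's per-feature quantile summaries are merged, and kFactor over-allocates the summaries before the final prune back to max_num_bins. A deliberately toy Scala stand-in for that idea (real WXQSketch summaries are weighted and mergeable; this sketch only pools candidate values and prunes them):

```scala
// Toy stand-in for distributed cut finding: pool each worker's candidate split
// values, then keep roughly maxBins evenly spaced quantiles of the pooled set.
def mergeCuts(perWorkerCandidates: Seq[Seq[Double]], maxBins: Int): Seq[Double] = {
  val pooled = perWorkerCandidates.flatten.sorted
  if (pooled.size <= maxBins) pooled.distinct
  else (1 to maxBins).map(i => pooled(((i.toLong * (pooled.size - 1)) / maxBins).toInt)).distinct
}

// Every worker computing this over the same pooled data gets the same bins,
// which is what the Allreduce above guarantees for the real sketches.
val cuts = mergeCuts(Seq(Seq(0.1, 0.5, 0.9), Seq(0.2, 0.6, 1.4)), maxBins = 4)
```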
@@ -479,14 +479,14 @@ void GHistBuilder::BuildHist(const std::vector<GradientPair>& gpair,

   #pragma omp parallel for num_threads(std::min(nthread, n_blocks)) schedule(guided)
   for (bst_omp_uint iblock = 0; iblock < n_blocks; iblock++) {
-    const size_t istart = iblock*block_size;
-    const size_t iend = (((iblock+1)*block_size > size) ? size : istart + block_size);
+    const size_t istart = iblock * block_size;
+    const size_t iend = (((iblock + 1) * block_size > size) ? size : istart + block_size);

-    const size_t bin = 2*thread_init_[0]*nbins_;
-    memcpy(hist_data + istart, (data + bin + istart), sizeof(double)*(iend - istart));
+    const size_t bin = 2 * thread_init_[0] * nbins_;
+    memcpy(hist_data + istart, (data + bin + istart), sizeof(double) * (iend - istart));

     for (size_t i_bin_part = 1; i_bin_part < n_worked_bins; ++i_bin_part) {
-      const size_t bin = 2*thread_init_[i_bin_part]*nbins_;
+      const size_t bin = 2 * thread_init_[i_bin_part] * nbins_;
       for (size_t i = istart; i < iend; i++) {
         hist_data[i] += data[bin + i];
       }
@@ -13,6 +13,7 @@
 #include "row_set.h"
 #include "../tree/param.h"
 #include "./quantile.h"
+#include "../include/rabit/rabit.h"

 namespace xgboost {
@@ -43,6 +44,10 @@ struct GHistEntry {
     sum_hess += e.sum_hess;
   }

+  inline static void Reduce(GHistEntry& a, const GHistEntry& b) { // NOLINT(*)
+    a.Add(b);
+  }
+
   /*! \brief set sum to be difference of two GHistEntry's */
   inline void SetSubtract(const GHistEntry& a, const GHistEntry& b) {
     sum_grad = a.sum_grad - b.sum_grad;
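Reduce is the combiner Rabit will use to add two workers' histograms element-wise, while the existing SetSubtract is what the subtraction trick builds on (parent = left + right, so right = parent - left). An illustrative Scala counterpart of those two operations (the case class and helper names are made up, not the C++ types):

```scala
// Illustrative counterpart of GHistEntry: per-bin gradient/hessian sums.
final case class Entry(sumGrad: Double, sumHess: Double) {
  def add(other: Entry): Entry = Entry(sumGrad + other.sumGrad, sumHess + other.sumHess)
  def subtract(other: Entry): Entry = Entry(sumGrad - other.sumGrad, sumHess - other.sumHess)
}

// Allreduce combiner: element-wise add of two workers' partial histograms.
def reduce(a: Seq[Entry], b: Seq[Entry]): Seq[Entry] =
  a.zip(b).map { case (x, y) => x.add(y) }

// Subtraction trick: derive the sibling's histogram without rebuilding it.
def subtractionTrick(parent: Seq[Entry], left: Seq[Entry]): Seq[Entry] =
  parent.zip(left).map { case (p, l) => p.subtract(l) }
```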
@@ -166,7 +171,7 @@ class GHistIndexBlockMatrix {
 };

 /*!
- * \brief histogram of graident statistics for a single node.
+ * \brief histogram of gradient statistics for a single node.
  * Consists of multiple GHistEntry's, each entry showing total graident statistics
  * for that particular bin
  * Uses global bin id so as to represent all features simultaneously
@@ -254,6 +259,10 @@ class GHistBuilder {
   // construct a histogram via subtraction trick
   void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent);

+  uint32_t GetNumBins() {
+    return nbins_;
+  }
+
  private:
   /*! \brief number of threads for parallel computation */
   size_t nthread_;
@@ -17,6 +17,8 @@
 #include <numeric>
 #include <random>

+#include "io.h"
+
 namespace xgboost {
 namespace common {
 /*!
@@ -598,8 +598,8 @@ class LearnerImpl : public Learner {
   }

   const TreeMethod current_tree_method = tparam_.tree_method;

   if (rabit::IsDistributed()) {
     /* Choose tree_method='approx' when distributed training is activated */
     CHECK(tparam_.dsplit != DataSplitMode::kAuto)
       << "Precondition violated; dsplit cannot be 'auto' in distributed mode";
     if (tparam_.dsplit == DataSplitMode::kCol) {
@@ -614,14 +614,13 @@ class LearnerImpl : public Learner {
                         "for distributed training.";
         break;
       case TreeMethod::kApprox:
+      case TreeMethod::kHist:
         // things are okay, do nothing
         break;
       case TreeMethod::kExact:
-      case TreeMethod::kHist:
-        LOG(WARNING) << "Tree method was set to be '"
-                     << (current_tree_method == TreeMethod::kExact ?
-                         "exact" : "hist")
-                     << "', but only 'approx' is available for distributed "
+        LOG(CONSOLE) << "Tree method was set to be "
+                     << "exact"
+                     << "', but only 'approx' and 'hist' is available for distributed "
                         "training. The `tree_method` parameter is now being "
                         "changed to 'approx'";
         break;
@@ -633,7 +632,15 @@ class LearnerImpl : public Learner {
         LOG(FATAL) << "Unknown tree_method ("
                    << static_cast<int>(current_tree_method) << ") detected";
     }
+    if (current_tree_method != TreeMethod::kHist) {
       LOG(CONSOLE) << "Tree method is automatically selected to be 'approx'"
                       " for distributed training.";
       tparam_.tree_method = TreeMethod::kApprox;
+    } else {
+      LOG(CONSOLE) << "Tree method is specified to be 'hist'"
+                      " for distributed training.";
+      tparam_.tree_method = TreeMethod::kHist;
+    }
   } else if (!p_train->SingleColBlock()) {
     /* Some tree methods are not available for external-memory DMatrix */
     switch (current_tree_method) {
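The net behaviour of the reworked selection: under distributed training, approx and hist now both pass through, and anything else falls back to approx with a console notice. A sketch of that rule as a pure function (the enum and function names here are illustrative, not the C++ types):

```scala
// Sketch of the distributed tree_method resolution after this change.
sealed trait TreeMethod
case object Auto extends TreeMethod
case object Exact extends TreeMethod
case object Approx extends TreeMethod
case object Hist extends TreeMethod

def resolveDistributed(requested: TreeMethod): TreeMethod = requested match {
  case Hist   => Hist    // new: hist is kept as-is for distributed training
  case Approx => Approx  // already supported
  case _      => Approx  // auto/exact fall back to approx, with a logged notice
}
```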
@@ -126,6 +126,7 @@ class HistMaker: public BaseMaker {
   virtual void Update(const std::vector<GradientPair> &gpair,
                       DMatrix *p_fmat,
                       RegTree *p_tree) {
+    CHECK(param_.max_depth > 0) << "max_depth must be larger than 0";
     this->InitData(gpair, *p_fmat, *p_tree);
     this->InitWorkSet(p_fmat, *p_tree, &fwork_set_);
     // mark root node as fresh.
@@ -345,10 +346,7 @@ class CQHistMaker: public HistMaker<TStats> {
     this->wspace_.Init(this->param_, 1);
-    // if it is C++11, use lazy evaluation for Allreduce,
-    // to gain speedup in recovery
-    #if __cplusplus >= 201103L
-    auto lazy_get_hist = [&]()
-    #endif
-    {
+    auto lazy_get_hist = [&]() {
       thread_hist_.resize(omp_get_max_threads());
       // start accumulating statistics
       for (const auto &batch : p_fmat->GetSortedColumnBatches()) {
@@ -371,20 +369,16 @@ class CQHistMaker: public HistMaker<TStats> {
       for (size_t i = 0; i < this->qexpand_.size(); ++i) {
         const int nid = this->qexpand_[i];
         const int wid = this->node2workindex_[nid];
-        this->wspace_.hset[0][fset.size() + wid * (fset.size()+1)]
+        this->wspace_.hset[0][fset.size() + wid * (fset.size() + 1)]
             .data[0] = node_stats_[nid];
       }
     };
     // sync the histogram
-    // if it is C++11, use lazy evaluation for Allreduce
-    #if __cplusplus >= 201103L
     this->histred_.Allreduce(dmlc::BeginPtr(this->wspace_.hset[0].data),
                              this->wspace_.hset[0].data.size(), lazy_get_hist);
-    #else
-    this->histred_.Allreduce(dmlc::BeginPtr(this->wspace_.hset[0].data),
-                             this->wspace_.hset[0].data.size());
-    #endif
   }

   void ResetPositionAfterSplit(DMatrix *p_fmat,
                                const RegTree &tree) override {
     this->GetSplitSet(this->qexpand_, tree, &fsplit_set_);
@@ -156,6 +156,11 @@ void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat,
       const int cright = (*p_tree)[nid].RightChild();
       hist_.AddHistRow(cleft);
       hist_.AddHistRow(cright);
+      if (rabit::IsDistributed()) {
+        // in distributed mode, we need to keep consistent across workers
+        BuildHist(gpair_h, row_set_collection_[cleft], gmat, gmatb, hist_[cleft]);
+        SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]);
+      } else {
         if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) {
           BuildHist(gpair_h, row_set_collection_[cleft], gmat, gmatb, hist_[cleft]);
           SubtractionTrick(hist_[cright], hist_[cleft], hist_[nid]);
@@ -163,6 +168,7 @@ void QuantileHistMaker::Builder::Update(const GHistIndexMatrix& gmat,
           BuildHist(gpair_h, row_set_collection_[cright], gmat, gmatb, hist_[cright]);
           SubtractionTrick(hist_[cleft], hist_[cright], hist_[nid]);
         }
+      }
       time_build_hist += dmlc::GetTime() - tstart;

       tstart = dmlc::GetTime();
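Why the distributed branch always builds the left child: the subtraction trick is only safe if every worker issues the same sequence of histogram allreduces, and the local "build the smaller child" heuristic can pick different children on different workers. A self-contained two-worker sketch of the resulting flow (toy numbers, tuple-based bins):

```scala
// Two workers each build a partial left-child histogram (per-bin grad/hess sums),
// allreduce it, then every worker derives the right child by subtraction.
type Bin = (Double, Double) // (sum_grad, sum_hess)
def add(a: Seq[Bin], b: Seq[Bin]): Seq[Bin] =
  a.zip(b).map { case ((g1, h1), (g2, h2)) => (g1 + g2, h1 + h2) }
def sub(a: Seq[Bin], b: Seq[Bin]): Seq[Bin] =
  a.zip(b).map { case ((g1, h1), (g2, h2)) => (g1 - g2, h1 - h2) }

val parentGlobal = Seq((4.0, 6.0), (2.0, 3.0))
val leftWorker0  = Seq((1.0, 2.0), (0.5, 1.0))
val leftWorker1  = Seq((1.5, 1.0), (0.5, 0.5))

val leftGlobal  = add(leftWorker0, leftWorker1)  // same Allreduce result on every worker
val rightGlobal = sub(parentGlobal, leftGlobal)  // identical everywhere, no extra pass over data
```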
@@ -617,11 +623,21 @@ void QuantileHistMaker::Builder::InitNewNode(int nid,

   {
     auto& stats = snode_[nid].stats;
-    if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
+    GHistRow hist = hist_[nid];
+    if (rabit::IsDistributed()) {
+      // in distributed mode, the node's stats should be calculated from histogram, otherwise,
+      // we will have wrong results in EnumerateSplit()
+      // here we take the last feature in cut
+      for (size_t i = gmat.cut.row_ptr[0]; i < gmat.cut.row_ptr[1]; i++) {
+        stats.Add(hist.begin[i].sum_grad, hist.begin[i].sum_hess);
+      }
+    } else {
+      if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased ||
+          rabit::IsDistributed()) {
       /* specialized code for dense data
          For dense data (with no missing value),
-         the sum of gradient histogram is equal to snode[nid] */
-      GHistRow hist = hist_[nid];
+         the sum of gradient histogram is equal to snode[nid]
+      GHistRow hist = hist_[nid];*/
       const std::vector<uint32_t>& row_ptr = gmat.cut.row_ptr;

       const uint32_t ibegin = row_ptr[fid_least_bins_];
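The distributed branch recovers a node's gradient/hessian totals from the already-synchronised histogram rather than from local rows, since local sums differ between workers; summing the bins of a single feature works because, in the dense case described by the comment, every row contributes to every feature. (The comment says "last feature" while the loop reads row_ptr[0]..row_ptr[1], i.e. the first feature; either works under that assumption.) A small sketch of that reduction with made-up names:

```scala
// Per-node histogram laid out over global bin ids; rowPtr(f) until rowPtr(f + 1)
// is the slice of bins belonging to feature f (mirrors gmat.cut.row_ptr).
def nodeStatsFromHist(hist: IndexedSeq[(Double, Double)],
                      rowPtr: IndexedSeq[Int],
                      feature: Int): (Double, Double) =
  (rowPtr(feature) until rowPtr(feature + 1))
    .map(hist)
    .foldLeft((0.0, 0.0)) { case ((g, h), (bg, bh)) => (g + bg, h + bh) }

// With two features occupying bins [0,3) and [3,5), summing feature 0's bins
// reproduces the node totals the non-distributed path computes from rows.
val hist = IndexedSeq((1.0, 2.0), (0.5, 1.0), (1.5, 3.0), (2.0, 4.0), (1.0, 2.0))
val rowPtr = IndexedSeq(0, 3, 5)
val stats = nodeStatsFromHist(hist, rowPtr, feature = 0) // (3.0, 6.0)
```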
@@ -637,6 +653,7 @@ void QuantileHistMaker::Builder::InitNewNode(int nid,
       }
     }
   }
+  }

   // calculating the weights
   {
@@ -105,6 +105,7 @@ class QuantileHistMaker: public TreeUpdater {
     } else {
       hist_builder_.BuildHist(gpair, row_indices, gmat, hist);
     }
+    this->histred_.Allreduce(hist.begin, hist_builder_.GetNumBins());
   }

   inline void SubtractionTrick(GHistRow self, GHistRow sibling, GHistRow parent) {
@@ -225,6 +226,8 @@ class QuantileHistMaker: public TreeUpdater {

     enum DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
     DataLayout data_layout_;
+
+    rabit::Reducer<GHistEntry, GHistEntry::Reduce> histred_;
   };

   std::unique_ptr<Builder> builder_;
@@ -52,10 +52,7 @@ class TreeRefresher: public TreeUpdater {
     }
-    // if it is C++11, use lazy evaluation for Allreduce,
-    // to gain speedup in recovery
-    #if __cplusplus >= 201103L
-    auto lazy_get_stats = [&]()
-    #endif
-    {
+    auto lazy_get_stats = [&]() {
       const MetaInfo &info = p_fmat->Info();
       // start accumulating statistics
       for (const auto &batch : p_fmat->GetRowBatches()) {
@@ -86,11 +83,7 @@ class TreeRefresher: public TreeUpdater {
         }
       }
     };
-    #if __cplusplus >= 201103L
     reducer_.Allreduce(dmlc::BeginPtr(stemp[0]), stemp[0].size(), lazy_get_stats);
-    #else
-    reducer_.Allreduce(dmlc::BeginPtr(stemp[0]), stemp[0].size());
-    #endif
     // rescale learning rate according to size of trees
     float lr = param_.learning_rate;
     param_.learning_rate = lr / trees.size();