Small refactor for hist builder. (#8698)
- Use span instead of vector as parameter. No perf change as the builder work on pointer. - Use const pointer for reg tree.
This commit is contained in:
parent
8af98e30fc
commit
21a28f2cc5
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2017-2020 by Contributors
|
* Copyright 2017-2023 by XGBoost Contributors
|
||||||
* \file hist_util.cc
|
* \file hist_util.cc
|
||||||
*/
|
*/
|
||||||
#include <dmlc/timer.h>
|
#include <dmlc/timer.h>
|
||||||
@ -193,9 +193,9 @@ class GHistBuildingManager {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <bool do_prefetch, class BuildingManager>
|
template <bool do_prefetch, class BuildingManager>
|
||||||
void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
void RowsWiseBuildHistKernel(Span<GradientPair const> gpair,
|
||||||
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
||||||
GHistRow hist) {
|
GHistRow hist) {
|
||||||
constexpr bool kAnyMissing = BuildingManager::kAnyMissing;
|
constexpr bool kAnyMissing = BuildingManager::kAnyMissing;
|
||||||
constexpr bool kFirstPage = BuildingManager::kFirstPage;
|
constexpr bool kFirstPage = BuildingManager::kFirstPage;
|
||||||
using BinIdxType = typename BuildingManager::BinIdxType;
|
using BinIdxType = typename BuildingManager::BinIdxType;
|
||||||
@ -262,9 +262,9 @@ void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <class BuildingManager>
|
template <class BuildingManager>
|
||||||
void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
void ColsWiseBuildHistKernel(Span<GradientPair const> gpair,
|
||||||
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
||||||
GHistRow hist) {
|
GHistRow hist) {
|
||||||
constexpr bool kAnyMissing = BuildingManager::kAnyMissing;
|
constexpr bool kAnyMissing = BuildingManager::kAnyMissing;
|
||||||
constexpr bool kFirstPage = BuildingManager::kFirstPage;
|
constexpr bool kFirstPage = BuildingManager::kFirstPage;
|
||||||
using BinIdxType = typename BuildingManager::BinIdxType;
|
using BinIdxType = typename BuildingManager::BinIdxType;
|
||||||
@ -315,9 +315,8 @@ void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <class BuildingManager>
|
template <class BuildingManager>
|
||||||
void BuildHistDispatch(const std::vector<GradientPair> &gpair,
|
void BuildHistDispatch(Span<GradientPair const> gpair, const RowSetCollection::Elem row_indices,
|
||||||
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
const GHistIndexMatrix &gmat, GHistRow hist) {
|
||||||
GHistRow hist) {
|
|
||||||
if (BuildingManager::kReadByColumn) {
|
if (BuildingManager::kReadByColumn) {
|
||||||
ColsWiseBuildHistKernel<BuildingManager>(gpair, row_indices, gmat, hist);
|
ColsWiseBuildHistKernel<BuildingManager>(gpair, row_indices, gmat, hist);
|
||||||
} else {
|
} else {
|
||||||
@ -344,33 +343,31 @@ void BuildHistDispatch(const std::vector<GradientPair> &gpair,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <bool any_missing>
|
template <bool any_missing>
|
||||||
void GHistBuilder::BuildHist(const std::vector<GradientPair> &gpair,
|
void GHistBuilder::BuildHist(Span<GradientPair const> gpair,
|
||||||
const RowSetCollection::Elem row_indices,
|
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
|
||||||
const GHistIndexMatrix &gmat,
|
|
||||||
GHistRow hist, bool force_read_by_column) const {
|
GHistRow hist, bool force_read_by_column) const {
|
||||||
/* force_read_by_column is used for testing the columnwise building of histograms.
|
/* force_read_by_column is used for testing the columnwise building of histograms.
|
||||||
* default force_read_by_column = false
|
* default force_read_by_column = false
|
||||||
*/
|
*/
|
||||||
constexpr double kAdhocL2Size = 1024 * 1024 * 0.8;
|
constexpr double kAdhocL2Size = 1024 * 1024 * 0.8;
|
||||||
const bool hist_fit_to_l2 = kAdhocL2Size > 2*sizeof(float)*gmat.cut.Ptrs().back();
|
const bool hist_fit_to_l2 = kAdhocL2Size > 2 * sizeof(float) * gmat.cut.Ptrs().back();
|
||||||
bool first_page = gmat.base_rowid == 0;
|
bool first_page = gmat.base_rowid == 0;
|
||||||
bool read_by_column = !hist_fit_to_l2 && !any_missing;
|
bool read_by_column = !hist_fit_to_l2 && !any_missing;
|
||||||
auto bin_type_size = gmat.index.GetBinTypeSize();
|
auto bin_type_size = gmat.index.GetBinTypeSize();
|
||||||
|
|
||||||
GHistBuildingManager<any_missing>::DispatchAndExecute(
|
GHistBuildingManager<any_missing>::DispatchAndExecute(
|
||||||
{first_page, read_by_column || force_read_by_column, bin_type_size},
|
{first_page, read_by_column || force_read_by_column, bin_type_size}, [&](auto t) {
|
||||||
[&](auto t) {
|
using BuildingManager = decltype(t);
|
||||||
using BuildingManager = decltype(t);
|
BuildHistDispatch<BuildingManager>(gpair, row_indices, gmat, hist);
|
||||||
BuildHistDispatch<BuildingManager>(gpair, row_indices, gmat, hist);
|
});
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template void GHistBuilder::BuildHist<true>(const std::vector<GradientPair> &gpair,
|
template void GHistBuilder::BuildHist<true>(Span<GradientPair const> gpair,
|
||||||
const RowSetCollection::Elem row_indices,
|
const RowSetCollection::Elem row_indices,
|
||||||
const GHistIndexMatrix &gmat, GHistRow hist,
|
const GHistIndexMatrix &gmat, GHistRow hist,
|
||||||
bool force_read_by_column) const;
|
bool force_read_by_column) const;
|
||||||
|
|
||||||
template void GHistBuilder::BuildHist<false>(const std::vector<GradientPair> &gpair,
|
template void GHistBuilder::BuildHist<false>(Span<GradientPair const> gpair,
|
||||||
const RowSetCollection::Elem row_indices,
|
const RowSetCollection::Elem row_indices,
|
||||||
const GHistIndexMatrix &gmat, GHistRow hist,
|
const GHistIndexMatrix &gmat, GHistRow hist,
|
||||||
bool force_read_by_column) const;
|
bool force_read_by_column) const;
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2017-2022 by XGBoost Contributors
|
* Copyright 2017-2023 by XGBoost Contributors
|
||||||
* \file hist_util.h
|
* \file hist_util.h
|
||||||
* \brief Utility for fast histogram aggregation
|
* \brief Utility for fast histogram aggregation
|
||||||
* \author Philip Cho, Tianqi Chen
|
* \author Philip Cho, Tianqi Chen
|
||||||
@ -23,6 +23,7 @@
|
|||||||
#include "row_set.h"
|
#include "row_set.h"
|
||||||
#include "threading_utils.h"
|
#include "threading_utils.h"
|
||||||
#include "timer.h"
|
#include "timer.h"
|
||||||
|
#include "xgboost/base.h" // bst_feature_t, bst_bin_t
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
class GHistIndexMatrix;
|
class GHistIndexMatrix;
|
||||||
@ -320,10 +321,10 @@ struct Index {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename GradientIndex>
|
template <typename GradientIndex>
|
||||||
bst_bin_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(size_t begin, size_t end,
|
bst_bin_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(std::size_t begin, std::size_t end,
|
||||||
GradientIndex const& data,
|
GradientIndex const& data,
|
||||||
uint32_t const fidx_begin,
|
bst_feature_t const fidx_begin,
|
||||||
uint32_t const fidx_end) {
|
bst_feature_t const fidx_end) {
|
||||||
size_t previous_middle = std::numeric_limits<size_t>::max();
|
size_t previous_middle = std::numeric_limits<size_t>::max();
|
||||||
while (end != begin) {
|
while (end != begin) {
|
||||||
size_t middle = begin + (end - begin) / 2;
|
size_t middle = begin + (end - begin) / 2;
|
||||||
@ -635,7 +636,7 @@ class GHistBuilder {
|
|||||||
|
|
||||||
// construct a histogram via histogram aggregation
|
// construct a histogram via histogram aggregation
|
||||||
template <bool any_missing>
|
template <bool any_missing>
|
||||||
void BuildHist(const std::vector<GradientPair>& gpair, const RowSetCollection::Elem row_indices,
|
void BuildHist(Span<GradientPair const> gpair, const RowSetCollection::Elem row_indices,
|
||||||
const GHistIndexMatrix& gmat, GHistRow hist,
|
const GHistIndexMatrix& gmat, GHistRow hist,
|
||||||
bool force_read_by_column = false) const;
|
bool force_read_by_column = false) const;
|
||||||
uint32_t GetNumBins() const {
|
uint32_t GetNumBins() const {
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
/*!
|
/**
|
||||||
* Copyright 2021-2022 by XGBoost Contributors
|
* Copyright 2021-2023 by XGBoost Contributors
|
||||||
*/
|
*/
|
||||||
#ifndef XGBOOST_TREE_HIST_HISTOGRAM_H_
|
#ifndef XGBOOST_TREE_HIST_HISTOGRAM_H_
|
||||||
#define XGBOOST_TREE_HIST_HISTOGRAM_H_
|
#define XGBOOST_TREE_HIST_HISTOGRAM_H_
|
||||||
@ -59,15 +59,14 @@ class HistogramBuilder {
|
|||||||
GHistIndexMatrix const &gidx,
|
GHistIndexMatrix const &gidx,
|
||||||
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
||||||
common::RowSetCollection const &row_set_collection,
|
common::RowSetCollection const &row_set_collection,
|
||||||
const std::vector<GradientPair> &gpair_h,
|
common::Span<GradientPair const> gpair_h, bool force_read_by_column) {
|
||||||
bool force_read_by_column) {
|
|
||||||
const size_t n_nodes = nodes_for_explicit_hist_build.size();
|
const size_t n_nodes = nodes_for_explicit_hist_build.size();
|
||||||
CHECK_GT(n_nodes, 0);
|
CHECK_GT(n_nodes, 0);
|
||||||
|
|
||||||
std::vector<common::GHistRow> target_hists(n_nodes);
|
std::vector<common::GHistRow> target_hists(n_nodes);
|
||||||
for (size_t i = 0; i < n_nodes; ++i) {
|
for (size_t i = 0; i < n_nodes; ++i) {
|
||||||
const int32_t nid = nodes_for_explicit_hist_build[i].nid;
|
auto const nidx = nodes_for_explicit_hist_build[i].nid;
|
||||||
target_hists[i] = hist_[nid];
|
target_hists[i] = hist_[nidx];
|
||||||
}
|
}
|
||||||
if (page_idx == 0) {
|
if (page_idx == 0) {
|
||||||
// FIXME(jiamingy): Handle different size of space. Right now we use the maximum
|
// FIXME(jiamingy): Handle different size of space. Right now we use the maximum
|
||||||
@ -93,46 +92,37 @@ class HistogramBuilder {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void
|
void AddHistRows(int *starting_index, int *sync_count,
|
||||||
AddHistRows(int *starting_index, int *sync_count,
|
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
||||||
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
|
||||||
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
|
RegTree const *p_tree) {
|
||||||
RegTree *p_tree) {
|
|
||||||
if (is_distributed_) {
|
if (is_distributed_) {
|
||||||
this->AddHistRowsDistributed(starting_index, sync_count,
|
this->AddHistRowsDistributed(starting_index, sync_count, nodes_for_explicit_hist_build,
|
||||||
nodes_for_explicit_hist_build,
|
|
||||||
nodes_for_subtraction_trick, p_tree);
|
nodes_for_subtraction_trick, p_tree);
|
||||||
} else {
|
} else {
|
||||||
this->AddHistRowsLocal(starting_index, sync_count,
|
this->AddHistRowsLocal(starting_index, sync_count, nodes_for_explicit_hist_build,
|
||||||
nodes_for_explicit_hist_build,
|
|
||||||
nodes_for_subtraction_trick);
|
nodes_for_subtraction_trick);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Main entry point of this class, build histogram for tree nodes. */
|
/** Main entry point of this class, build histogram for tree nodes. */
|
||||||
void BuildHist(size_t page_id, common::BlockedSpace2d space, GHistIndexMatrix const &gidx,
|
void BuildHist(size_t page_id, common::BlockedSpace2d space, GHistIndexMatrix const &gidx,
|
||||||
RegTree *p_tree, common::RowSetCollection const &row_set_collection,
|
RegTree const *p_tree, common::RowSetCollection const &row_set_collection,
|
||||||
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
||||||
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
|
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
|
||||||
std::vector<GradientPair> const &gpair,
|
common::Span<GradientPair const> gpair, bool force_read_by_column = false) {
|
||||||
bool force_read_by_column = false) {
|
|
||||||
int starting_index = std::numeric_limits<int>::max();
|
int starting_index = std::numeric_limits<int>::max();
|
||||||
int sync_count = 0;
|
int sync_count = 0;
|
||||||
if (page_id == 0) {
|
if (page_id == 0) {
|
||||||
this->AddHistRows(&starting_index, &sync_count,
|
this->AddHistRows(&starting_index, &sync_count, nodes_for_explicit_hist_build,
|
||||||
nodes_for_explicit_hist_build,
|
|
||||||
nodes_for_subtraction_trick, p_tree);
|
nodes_for_subtraction_trick, p_tree);
|
||||||
}
|
}
|
||||||
if (gidx.IsDense()) {
|
if (gidx.IsDense()) {
|
||||||
this->BuildLocalHistograms<false>(page_id, space, gidx,
|
this->BuildLocalHistograms<false>(page_id, space, gidx, nodes_for_explicit_hist_build,
|
||||||
nodes_for_explicit_hist_build,
|
row_set_collection, gpair, force_read_by_column);
|
||||||
row_set_collection, gpair,
|
|
||||||
force_read_by_column);
|
|
||||||
} else {
|
} else {
|
||||||
this->BuildLocalHistograms<true>(page_id, space, gidx,
|
this->BuildLocalHistograms<true>(page_id, space, gidx, nodes_for_explicit_hist_build,
|
||||||
nodes_for_explicit_hist_build,
|
row_set_collection, gpair, force_read_by_column);
|
||||||
row_set_collection, gpair,
|
|
||||||
force_read_by_column);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
CHECK_GE(n_batches_, 1);
|
CHECK_GE(n_batches_, 1);
|
||||||
@ -153,8 +143,7 @@ class HistogramBuilder {
|
|||||||
common::RowSetCollection const &row_set_collection,
|
common::RowSetCollection const &row_set_collection,
|
||||||
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
||||||
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
|
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
|
||||||
std::vector<GradientPair> const &gpair,
|
common::Span<GradientPair const> gpair, bool force_read_by_column = false) {
|
||||||
bool force_read_by_column = false) {
|
|
||||||
const size_t n_nodes = nodes_for_explicit_hist_build.size();
|
const size_t n_nodes = nodes_for_explicit_hist_build.size();
|
||||||
// create space of size (# rows in each node)
|
// create space of size (# rows in each node)
|
||||||
common::BlockedSpace2d space(
|
common::BlockedSpace2d space(
|
||||||
@ -164,83 +153,72 @@ class HistogramBuilder {
|
|||||||
return row_set_collection[nidx].Size();
|
return row_set_collection[nidx].Size();
|
||||||
},
|
},
|
||||||
256);
|
256);
|
||||||
this->BuildHist(page_id, space, gidx, p_tree, row_set_collection,
|
this->BuildHist(page_id, space, gidx, p_tree, row_set_collection, nodes_for_explicit_hist_build,
|
||||||
nodes_for_explicit_hist_build, nodes_for_subtraction_trick,
|
nodes_for_subtraction_trick, gpair, force_read_by_column);
|
||||||
gpair, force_read_by_column);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SyncHistogramDistributed(
|
void SyncHistogramDistributed(RegTree const *p_tree,
|
||||||
RegTree *p_tree,
|
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
||||||
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
|
||||||
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
|
int starting_index, int sync_count) {
|
||||||
int starting_index, int sync_count) {
|
|
||||||
const size_t nbins = builder_.GetNumBins();
|
const size_t nbins = builder_.GetNumBins();
|
||||||
common::BlockedSpace2d space(
|
common::BlockedSpace2d space(
|
||||||
nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; },
|
nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; }, 1024);
|
||||||
1024);
|
common::ParallelFor2d(space, n_threads_, [&](size_t node, common::Range1d r) {
|
||||||
common::ParallelFor2d(
|
const auto &entry = nodes_for_explicit_hist_build[node];
|
||||||
space, n_threads_, [&](size_t node, common::Range1d r) {
|
auto this_hist = this->hist_[entry.nid];
|
||||||
const auto &entry = nodes_for_explicit_hist_build[node];
|
// Merging histograms from each thread into once
|
||||||
auto this_hist = this->hist_[entry.nid];
|
buffer_.ReduceHist(node, r.begin(), r.end());
|
||||||
// Merging histograms from each thread into once
|
// Store posible parent node
|
||||||
buffer_.ReduceHist(node, r.begin(), r.end());
|
auto this_local = hist_local_worker_[entry.nid];
|
||||||
// Store posible parent node
|
common::CopyHist(this_local, this_hist, r.begin(), r.end());
|
||||||
auto this_local = hist_local_worker_[entry.nid];
|
|
||||||
common::CopyHist(this_local, this_hist, r.begin(), r.end());
|
|
||||||
|
|
||||||
if (!(*p_tree)[entry.nid].IsRoot()) {
|
if (!(*p_tree)[entry.nid].IsRoot()) {
|
||||||
const size_t parent_id = (*p_tree)[entry.nid].Parent();
|
const size_t parent_id = (*p_tree)[entry.nid].Parent();
|
||||||
const int subtraction_node_id =
|
const int subtraction_node_id = nodes_for_subtraction_trick[node].nid;
|
||||||
nodes_for_subtraction_trick[node].nid;
|
auto parent_hist = this->hist_local_worker_[parent_id];
|
||||||
auto parent_hist = this->hist_local_worker_[parent_id];
|
auto sibling_hist = this->hist_[subtraction_node_id];
|
||||||
auto sibling_hist = this->hist_[subtraction_node_id];
|
common::SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
|
||||||
common::SubtractionHist(sibling_hist, parent_hist, this_hist,
|
// Store posible parent node
|
||||||
r.begin(), r.end());
|
auto sibling_local = hist_local_worker_[subtraction_node_id];
|
||||||
// Store posible parent node
|
common::CopyHist(sibling_local, sibling_hist, r.begin(), r.end());
|
||||||
auto sibling_local = hist_local_worker_[subtraction_node_id];
|
}
|
||||||
common::CopyHist(sibling_local, sibling_hist, r.begin(), r.end());
|
});
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
collective::Allreduce<collective::Operation::kSum>(
|
collective::Allreduce<collective::Operation::kSum>(
|
||||||
reinterpret_cast<double *>(this->hist_[starting_index].data()),
|
reinterpret_cast<double *>(this->hist_[starting_index].data()),
|
||||||
builder_.GetNumBins() * sync_count * 2);
|
builder_.GetNumBins() * sync_count * 2);
|
||||||
|
|
||||||
ParallelSubtractionHist(space, nodes_for_explicit_hist_build,
|
ParallelSubtractionHist(space, nodes_for_explicit_hist_build, nodes_for_subtraction_trick,
|
||||||
nodes_for_subtraction_trick, p_tree);
|
p_tree);
|
||||||
|
|
||||||
common::BlockedSpace2d space2(
|
common::BlockedSpace2d space2(
|
||||||
nodes_for_subtraction_trick.size(), [&](size_t) { return nbins; },
|
nodes_for_subtraction_trick.size(), [&](size_t) { return nbins; }, 1024);
|
||||||
1024);
|
ParallelSubtractionHist(space2, nodes_for_subtraction_trick, nodes_for_explicit_hist_build,
|
||||||
ParallelSubtractionHist(space2, nodes_for_subtraction_trick,
|
p_tree);
|
||||||
nodes_for_explicit_hist_build, p_tree);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SyncHistogramLocal(RegTree *p_tree,
|
void SyncHistogramLocal(RegTree const *p_tree,
|
||||||
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
||||||
std::vector<ExpandEntry> const &nodes_for_subtraction_trick) {
|
std::vector<ExpandEntry> const &nodes_for_subtraction_trick) {
|
||||||
const size_t nbins = this->builder_.GetNumBins();
|
const size_t nbins = this->builder_.GetNumBins();
|
||||||
common::BlockedSpace2d space(
|
common::BlockedSpace2d space(
|
||||||
nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; },
|
nodes_for_explicit_hist_build.size(), [&](size_t) { return nbins; }, 1024);
|
||||||
1024);
|
|
||||||
|
|
||||||
common::ParallelFor2d(
|
common::ParallelFor2d(space, this->n_threads_, [&](size_t node, common::Range1d r) {
|
||||||
space, this->n_threads_, [&](size_t node, common::Range1d r) {
|
const auto &entry = nodes_for_explicit_hist_build[node];
|
||||||
const auto &entry = nodes_for_explicit_hist_build[node];
|
auto this_hist = this->hist_[entry.nid];
|
||||||
auto this_hist = this->hist_[entry.nid];
|
// Merging histograms from each thread into once
|
||||||
// Merging histograms from each thread into once
|
this->buffer_.ReduceHist(node, r.begin(), r.end());
|
||||||
this->buffer_.ReduceHist(node, r.begin(), r.end());
|
|
||||||
|
|
||||||
if (!(*p_tree)[entry.nid].IsRoot()) {
|
if (!(*p_tree)[entry.nid].IsRoot()) {
|
||||||
const size_t parent_id = (*p_tree)[entry.nid].Parent();
|
auto const parent_id = (*p_tree)[entry.nid].Parent();
|
||||||
const int subtraction_node_id =
|
auto const subtraction_node_id = nodes_for_subtraction_trick[node].nid;
|
||||||
nodes_for_subtraction_trick[node].nid;
|
auto parent_hist = this->hist_[parent_id];
|
||||||
auto parent_hist = this->hist_[parent_id];
|
auto sibling_hist = this->hist_[subtraction_node_id];
|
||||||
auto sibling_hist = this->hist_[subtraction_node_id];
|
common::SubtractionHist(sibling_hist, parent_hist, this_hist, r.begin(), r.end());
|
||||||
common::SubtractionHist(sibling_hist, parent_hist, this_hist,
|
}
|
||||||
r.begin(), r.end());
|
});
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
@ -289,11 +267,10 @@ class HistogramBuilder {
|
|||||||
this->hist_.AllocateAllData();
|
this->hist_.AllocateAllData();
|
||||||
}
|
}
|
||||||
|
|
||||||
void AddHistRowsDistributed(
|
void AddHistRowsDistributed(int *starting_index, int *sync_count,
|
||||||
int *starting_index, int *sync_count,
|
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
||||||
std::vector<ExpandEntry> const &nodes_for_explicit_hist_build,
|
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
|
||||||
std::vector<ExpandEntry> const &nodes_for_subtraction_trick,
|
RegTree const *p_tree) {
|
||||||
RegTree *p_tree) {
|
|
||||||
const size_t explicit_size = nodes_for_explicit_hist_build.size();
|
const size_t explicit_size = nodes_for_explicit_hist_build.size();
|
||||||
const size_t subtaction_size = nodes_for_subtraction_trick.size();
|
const size_t subtaction_size = nodes_for_subtraction_trick.size();
|
||||||
std::vector<int> merged_node_ids(explicit_size + subtaction_size);
|
std::vector<int> merged_node_ids(explicit_size + subtaction_size);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user