Initial support for multi-target tree. (#8616)
* Implement multi-target for hist. - Add new hist tree builder. - Move data fetchers for tests. - Dispatch function calls in gbm base on the tree type.
This commit is contained in:
@@ -306,9 +306,9 @@ class HistogramBuilder {
|
||||
|
||||
// Construct a work space for building histogram. Eventually we should move this
|
||||
// function into histogram builder once hist tree method supports external memory.
|
||||
template <typename Partitioner>
|
||||
template <typename Partitioner, typename ExpandEntry = CPUExpandEntry>
|
||||
common::BlockedSpace2d ConstructHistSpace(Partitioner const &partitioners,
|
||||
std::vector<CPUExpandEntry> const &nodes_to_build) {
|
||||
std::vector<ExpandEntry> const &nodes_to_build) {
|
||||
std::vector<size_t> partition_size(nodes_to_build.size(), 0);
|
||||
for (auto const &partition : partitioners) {
|
||||
size_t k = 0;
|
||||
|
||||
@@ -889,6 +889,8 @@ void RegTree::Save(dmlc::Stream* fo) const {
|
||||
CHECK_EQ(param_.num_nodes, static_cast<int>(stats_.size()));
|
||||
CHECK_EQ(param_.deprecated_num_roots, 1);
|
||||
CHECK_NE(param_.num_nodes, 0);
|
||||
CHECK(!IsMultiTarget())
|
||||
<< "Please use JSON/UBJSON for saving models with multi-target trees.";
|
||||
CHECK(!HasCategoricalSplit())
|
||||
<< "Please use JSON/UBJSON for saving models with categorical splits.";
|
||||
|
||||
|
||||
@@ -4,36 +4,39 @@
|
||||
* \brief use quantized feature values to construct a tree
|
||||
* \author Philip Cho, Tianqi Checn, Egor Smirnov
|
||||
*/
|
||||
#include <algorithm> // for max
|
||||
#include <algorithm> // for max, copy, transform
|
||||
#include <cstddef> // for size_t
|
||||
#include <cstdint> // for uint32_t
|
||||
#include <memory> // for unique_ptr, allocator, make_unique, make_shared
|
||||
#include <ostream> // for operator<<, char_traits, basic_ostream
|
||||
#include <tuple> // for apply
|
||||
#include <cstdint> // for uint32_t, int32_t
|
||||
#include <memory> // for unique_ptr, allocator, make_unique, shared_ptr
|
||||
#include <numeric> // for accumulate
|
||||
#include <ostream> // for basic_ostream, char_traits, operator<<
|
||||
#include <utility> // for move, swap
|
||||
#include <vector> // for vector
|
||||
|
||||
#include "../collective/communicator-inl.h" // for Allreduce, IsDistributed
|
||||
#include "../collective/communicator.h" // for Operation
|
||||
#include "../common/hist_util.h" // for HistogramCuts, HistCollection
|
||||
#include "../common/linalg_op.h" // for begin, cbegin, cend
|
||||
#include "../common/random.h" // for ColumnSampler
|
||||
#include "../common/threading_utils.h" // for ParallelFor
|
||||
#include "../common/timer.h" // for Monitor
|
||||
#include "../common/transform_iterator.h" // for IndexTransformIter, MakeIndexTransformIter
|
||||
#include "../data/gradient_index.h" // for GHistIndexMatrix
|
||||
#include "common_row_partitioner.h" // for CommonRowPartitioner
|
||||
#include "dmlc/omp.h" // for omp_get_thread_num
|
||||
#include "dmlc/registry.h" // for DMLC_REGISTRY_FILE_TAG
|
||||
#include "driver.h" // for Driver
|
||||
#include "hist/evaluate_splits.h" // for HistEvaluator, UpdatePredictionCacheImpl
|
||||
#include "hist/expand_entry.h" // for CPUExpandEntry
|
||||
#include "hist/evaluate_splits.h" // for HistEvaluator, HistMultiEvaluator, UpdatePre...
|
||||
#include "hist/expand_entry.h" // for MultiExpandEntry, CPUExpandEntry
|
||||
#include "hist/histogram.h" // for HistogramBuilder, ConstructHistSpace
|
||||
#include "hist/sampler.h" // for SampleGradient
|
||||
#include "param.h" // for TrainParam, GradStats
|
||||
#include "xgboost/base.h" // for GradientPair, GradientPairInternal, bst_node_t
|
||||
#include "param.h" // for TrainParam, SplitEntryContainer, GradStats
|
||||
#include "xgboost/base.h" // for GradientPairInternal, GradientPair, bst_targ...
|
||||
#include "xgboost/context.h" // for Context
|
||||
#include "xgboost/data.h" // for BatchIterator, BatchSet, DMatrix, MetaInfo
|
||||
#include "xgboost/host_device_vector.h" // for HostDeviceVector
|
||||
#include "xgboost/linalg.h" // for TensorView, MatrixView, UnravelIndex, All
|
||||
#include "xgboost/logging.h" // for LogCheck_EQ, LogCheck_GE, CHECK_EQ, LOG, LOG...
|
||||
#include "xgboost/linalg.h" // for All, MatrixView, TensorView, Matrix, Empty
|
||||
#include "xgboost/logging.h" // for LogCheck_EQ, CHECK_EQ, CHECK, LogCheck_GE
|
||||
#include "xgboost/span.h" // for Span, operator!=, SpanIterator
|
||||
#include "xgboost/string_view.h" // for operator<<
|
||||
#include "xgboost/task.h" // for ObjInfo
|
||||
@@ -105,6 +108,212 @@ void UpdateTree(common::Monitor *monitor_, linalg::MatrixView<GradientPair const
|
||||
monitor_->Stop(__func__);
|
||||
}
|
||||
|
||||
/**
|
||||
* \brief Updater for building multi-target trees. The implementation simply iterates over
|
||||
* each target.
|
||||
*/
|
||||
class MultiTargetHistBuilder {
|
||||
private:
|
||||
common::Monitor *monitor_{nullptr};
|
||||
TrainParam const *param_{nullptr};
|
||||
std::shared_ptr<common::ColumnSampler> col_sampler_;
|
||||
std::unique_ptr<HistMultiEvaluator> evaluator_;
|
||||
// Histogram builder for each target.
|
||||
std::vector<HistogramBuilder<MultiExpandEntry>> histogram_builder_;
|
||||
Context const *ctx_{nullptr};
|
||||
// Partitioner for each data batch.
|
||||
std::vector<CommonRowPartitioner> partitioner_;
|
||||
// Pointer to last updated tree, used for update prediction cache.
|
||||
RegTree const *p_last_tree_{nullptr};
|
||||
|
||||
ObjInfo const *task_{nullptr};
|
||||
|
||||
public:
|
||||
void UpdatePosition(DMatrix *p_fmat, RegTree const *p_tree,
|
||||
std::vector<MultiExpandEntry> const &applied) {
|
||||
monitor_->Start(__func__);
|
||||
std::size_t page_id{0};
|
||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(this->param_))) {
|
||||
this->partitioner_.at(page_id).UpdatePosition(this->ctx_, page, applied, p_tree);
|
||||
page_id++;
|
||||
}
|
||||
monitor_->Stop(__func__);
|
||||
}
|
||||
|
||||
void ApplyTreeSplit(MultiExpandEntry const &candidate, RegTree *p_tree) {
|
||||
this->evaluator_->ApplyTreeSplit(candidate, p_tree);
|
||||
}
|
||||
|
||||
void InitData(DMatrix *p_fmat, RegTree const *p_tree) {
|
||||
monitor_->Start(__func__);
|
||||
|
||||
std::size_t page_id = 0;
|
||||
bst_bin_t n_total_bins = 0;
|
||||
partitioner_.clear();
|
||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
||||
if (n_total_bins == 0) {
|
||||
n_total_bins = page.cut.TotalBins();
|
||||
} else {
|
||||
CHECK_EQ(n_total_bins, page.cut.TotalBins());
|
||||
}
|
||||
partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit());
|
||||
page_id++;
|
||||
}
|
||||
|
||||
bst_target_t n_targets = p_tree->NumTargets();
|
||||
histogram_builder_.clear();
|
||||
for (std::size_t i = 0; i < n_targets; ++i) {
|
||||
histogram_builder_.emplace_back();
|
||||
histogram_builder_.back().Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
|
||||
collective::IsDistributed(), p_fmat->IsColumnSplit());
|
||||
}
|
||||
|
||||
evaluator_ = std::make_unique<HistMultiEvaluator>(ctx_, p_fmat->Info(), param_, col_sampler_);
|
||||
p_last_tree_ = p_tree;
|
||||
monitor_->Stop(__func__);
|
||||
}
|
||||
|
||||
MultiExpandEntry InitRoot(DMatrix *p_fmat, linalg::MatrixView<GradientPair const> gpair,
|
||||
RegTree *p_tree) {
|
||||
monitor_->Start(__func__);
|
||||
MultiExpandEntry best;
|
||||
best.nid = RegTree::kRoot;
|
||||
best.depth = 0;
|
||||
|
||||
auto n_targets = p_tree->NumTargets();
|
||||
linalg::Matrix<GradientPairPrecise> root_sum_tloc =
|
||||
linalg::Empty<GradientPairPrecise>(ctx_, ctx_->Threads(), n_targets);
|
||||
CHECK_EQ(root_sum_tloc.Shape(1), gpair.Shape(1));
|
||||
auto h_root_sum_tloc = root_sum_tloc.HostView();
|
||||
common::ParallelFor(gpair.Shape(0), ctx_->Threads(), [&](auto i) {
|
||||
for (bst_target_t t{0}; t < n_targets; ++t) {
|
||||
h_root_sum_tloc(omp_get_thread_num(), t) += GradientPairPrecise{gpair(i, t)};
|
||||
}
|
||||
});
|
||||
// Aggregate to the first row.
|
||||
auto root_sum = h_root_sum_tloc.Slice(0, linalg::All());
|
||||
for (std::int32_t tidx{1}; tidx < ctx_->Threads(); ++tidx) {
|
||||
for (bst_target_t t{0}; t < n_targets; ++t) {
|
||||
root_sum(t) += h_root_sum_tloc(tidx, t);
|
||||
}
|
||||
}
|
||||
CHECK(root_sum.CContiguous());
|
||||
collective::Allreduce<collective::Operation::kSum>(
|
||||
reinterpret_cast<double *>(root_sum.Values().data()), root_sum.Size() * 2);
|
||||
|
||||
std::vector<MultiExpandEntry> nodes{best};
|
||||
std::size_t i = 0;
|
||||
auto space = ConstructHistSpace(partitioner_, nodes);
|
||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
||||
for (bst_target_t t{0}; t < n_targets; ++t) {
|
||||
auto t_gpair = gpair.Slice(linalg::All(), t);
|
||||
histogram_builder_[t].BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
|
||||
nodes, {}, t_gpair.Values());
|
||||
}
|
||||
i++;
|
||||
}
|
||||
|
||||
auto weight = evaluator_->InitRoot(root_sum);
|
||||
auto weight_t = weight.HostView();
|
||||
std::transform(linalg::cbegin(weight_t), linalg::cend(weight_t), linalg::begin(weight_t),
|
||||
[&](float w) { return w * param_->learning_rate; });
|
||||
|
||||
p_tree->SetLeaf(RegTree::kRoot, weight_t);
|
||||
std::vector<common::HistCollection const *> hists;
|
||||
for (bst_target_t t{0}; t < p_tree->NumTargets(); ++t) {
|
||||
hists.push_back(&histogram_builder_[t].Histogram());
|
||||
}
|
||||
for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
||||
evaluator_->EvaluateSplits(*p_tree, hists, gmat.cut, &nodes);
|
||||
break;
|
||||
}
|
||||
monitor_->Stop(__func__);
|
||||
|
||||
return nodes.front();
|
||||
}
|
||||
|
||||
void BuildHistogram(DMatrix *p_fmat, RegTree const *p_tree,
|
||||
std::vector<MultiExpandEntry> const &valid_candidates,
|
||||
linalg::MatrixView<GradientPair const> gpair) {
|
||||
monitor_->Start(__func__);
|
||||
std::vector<MultiExpandEntry> nodes_to_build;
|
||||
std::vector<MultiExpandEntry> nodes_to_sub;
|
||||
|
||||
for (auto const &c : valid_candidates) {
|
||||
auto left_nidx = p_tree->LeftChild(c.nid);
|
||||
auto right_nidx = p_tree->RightChild(c.nid);
|
||||
|
||||
auto build_nidx = left_nidx;
|
||||
auto subtract_nidx = right_nidx;
|
||||
auto lit =
|
||||
common::MakeIndexTransformIter([&](auto i) { return c.split.left_sum[i].GetHess(); });
|
||||
auto left_sum = std::accumulate(lit, lit + c.split.left_sum.size(), .0);
|
||||
auto rit =
|
||||
common::MakeIndexTransformIter([&](auto i) { return c.split.right_sum[i].GetHess(); });
|
||||
auto right_sum = std::accumulate(rit, rit + c.split.right_sum.size(), .0);
|
||||
auto fewer_right = right_sum < left_sum;
|
||||
if (fewer_right) {
|
||||
std::swap(build_nidx, subtract_nidx);
|
||||
}
|
||||
nodes_to_build.emplace_back(build_nidx, p_tree->GetDepth(build_nidx));
|
||||
nodes_to_sub.emplace_back(subtract_nidx, p_tree->GetDepth(subtract_nidx));
|
||||
}
|
||||
|
||||
std::size_t i = 0;
|
||||
auto space = ConstructHistSpace(partitioner_, nodes_to_build);
|
||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
||||
for (std::size_t t = 0; t < p_tree->NumTargets(); ++t) {
|
||||
auto t_gpair = gpair.Slice(linalg::All(), t);
|
||||
// Make sure the gradient matrix is f-order.
|
||||
CHECK(t_gpair.Contiguous());
|
||||
histogram_builder_[t].BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
|
||||
nodes_to_build, nodes_to_sub, t_gpair.Values());
|
||||
}
|
||||
i++;
|
||||
}
|
||||
monitor_->Stop(__func__);
|
||||
}
|
||||
|
||||
void EvaluateSplits(DMatrix *p_fmat, RegTree const *p_tree,
|
||||
std::vector<MultiExpandEntry> *best_splits) {
|
||||
monitor_->Start(__func__);
|
||||
std::vector<common::HistCollection const *> hists;
|
||||
for (bst_target_t t{0}; t < p_tree->NumTargets(); ++t) {
|
||||
hists.push_back(&histogram_builder_[t].Histogram());
|
||||
}
|
||||
for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
||||
evaluator_->EvaluateSplits(*p_tree, hists, gmat.cut, best_splits);
|
||||
break;
|
||||
}
|
||||
monitor_->Stop(__func__);
|
||||
}
|
||||
|
||||
void LeafPartition(RegTree const &tree, linalg::MatrixView<GradientPair const> gpair,
|
||||
std::vector<bst_node_t> *p_out_position) {
|
||||
monitor_->Start(__func__);
|
||||
if (!task_->UpdateTreeLeaf()) {
|
||||
return;
|
||||
}
|
||||
for (auto const &part : partitioner_) {
|
||||
part.LeafPartition(ctx_, tree, gpair, p_out_position);
|
||||
}
|
||||
monitor_->Stop(__func__);
|
||||
}
|
||||
|
||||
public:
|
||||
explicit MultiTargetHistBuilder(Context const *ctx, MetaInfo const &info, TrainParam const *param,
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler,
|
||||
ObjInfo const *task, common::Monitor *monitor)
|
||||
: monitor_{monitor},
|
||||
param_{param},
|
||||
col_sampler_{std::move(column_sampler)},
|
||||
evaluator_{std::make_unique<HistMultiEvaluator>(ctx, info, param, col_sampler_)},
|
||||
ctx_{ctx},
|
||||
task_{task} {
|
||||
monitor_->Init(__func__);
|
||||
}
|
||||
};
|
||||
|
||||
class HistBuilder {
|
||||
private:
|
||||
common::Monitor *monitor_;
|
||||
@@ -155,8 +364,7 @@ class HistBuilder {
|
||||
// initialize temp data structure
|
||||
void InitData(DMatrix *fmat, RegTree const *p_tree) {
|
||||
monitor_->Start(__func__);
|
||||
|
||||
size_t page_id{0};
|
||||
std::size_t page_id{0};
|
||||
bst_bin_t n_total_bins{0};
|
||||
partitioner_.clear();
|
||||
for (auto const &page : fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
||||
@@ -195,7 +403,7 @@ class HistBuilder {
|
||||
RegTree *p_tree) {
|
||||
CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0));
|
||||
|
||||
size_t page_id = 0;
|
||||
std::size_t page_id = 0;
|
||||
auto space = ConstructHistSpace(partitioner_, {node});
|
||||
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
||||
std::vector<CPUExpandEntry> nodes_to_build{node};
|
||||
@@ -214,13 +422,13 @@ class HistBuilder {
|
||||
* of gradient histogram is equal to snode[nid]
|
||||
*/
|
||||
auto const &gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_)).begin());
|
||||
std::vector<uint32_t> const &row_ptr = gmat.cut.Ptrs();
|
||||
std::vector<std::uint32_t> const &row_ptr = gmat.cut.Ptrs();
|
||||
CHECK_GE(row_ptr.size(), 2);
|
||||
uint32_t const ibegin = row_ptr[0];
|
||||
uint32_t const iend = row_ptr[1];
|
||||
std::uint32_t const ibegin = row_ptr[0];
|
||||
std::uint32_t const iend = row_ptr[1];
|
||||
auto hist = this->histogram_builder_->Histogram()[RegTree::kRoot];
|
||||
auto begin = hist.data();
|
||||
for (uint32_t i = ibegin; i < iend; ++i) {
|
||||
for (std::uint32_t i = ibegin; i < iend; ++i) {
|
||||
GradientPairPrecise const &et = begin[i];
|
||||
grad_stat.Add(et.GetGrad(), et.GetHess());
|
||||
}
|
||||
@@ -259,7 +467,7 @@ class HistBuilder {
|
||||
std::vector<CPUExpandEntry> nodes_to_build(valid_candidates.size());
|
||||
std::vector<CPUExpandEntry> nodes_to_sub(valid_candidates.size());
|
||||
|
||||
size_t n_idx = 0;
|
||||
std::size_t n_idx = 0;
|
||||
for (auto const &c : valid_candidates) {
|
||||
auto left_nidx = (*p_tree)[c.nid].LeftChild();
|
||||
auto right_nidx = (*p_tree)[c.nid].RightChild();
|
||||
@@ -275,7 +483,7 @@ class HistBuilder {
|
||||
n_idx++;
|
||||
}
|
||||
|
||||
size_t page_id{0};
|
||||
std::size_t page_id{0};
|
||||
auto space = ConstructHistSpace(partitioner_, nodes_to_build);
|
||||
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
|
||||
histogram_builder_->BuildHist(page_id, space, gidx, p_tree,
|
||||
@@ -311,11 +519,12 @@ class HistBuilder {
|
||||
|
||||
/*! \brief construct a tree using quantized feature values */
|
||||
class QuantileHistMaker : public TreeUpdater {
|
||||
std::unique_ptr<HistBuilder> p_impl_;
|
||||
std::unique_ptr<HistBuilder> p_impl_{nullptr};
|
||||
std::unique_ptr<MultiTargetHistBuilder> p_mtimpl_{nullptr};
|
||||
std::shared_ptr<common::ColumnSampler> column_sampler_ =
|
||||
std::make_shared<common::ColumnSampler>();
|
||||
common::Monitor monitor_;
|
||||
ObjInfo const *task_;
|
||||
ObjInfo const *task_{nullptr};
|
||||
|
||||
public:
|
||||
explicit QuantileHistMaker(Context const *ctx, ObjInfo const *task)
|
||||
@@ -332,7 +541,10 @@ class QuantileHistMaker : public TreeUpdater {
|
||||
const std::vector<RegTree *> &trees) override {
|
||||
if (trees.front()->IsMultiTarget()) {
|
||||
CHECK(param->monotone_constraints.empty()) << "monotone constraint" << MTNotImplemented();
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
if (!p_mtimpl_) {
|
||||
this->p_mtimpl_ = std::make_unique<MultiTargetHistBuilder>(
|
||||
ctx_, p_fmat->Info(), param, column_sampler_, task_, &monitor_);
|
||||
}
|
||||
} else {
|
||||
if (!p_impl_) {
|
||||
p_impl_ =
|
||||
@@ -355,13 +567,14 @@ class QuantileHistMaker : public TreeUpdater {
|
||||
|
||||
for (auto tree_it = trees.begin(); tree_it != trees.end(); ++tree_it) {
|
||||
if (need_copy()) {
|
||||
// Copy gradient into buffer for sampling.
|
||||
// Copy gradient into buffer for sampling. This converts C-order to F-order.
|
||||
std::copy(linalg::cbegin(h_gpair), linalg::cend(h_gpair), linalg::begin(h_sample_out));
|
||||
}
|
||||
SampleGradient(ctx_, *param, h_sample_out);
|
||||
auto *h_out_position = &out_position[tree_it - trees.begin()];
|
||||
if ((*tree_it)->IsMultiTarget()) {
|
||||
LOG(FATAL) << "Not implemented.";
|
||||
UpdateTree<MultiExpandEntry>(&monitor_, h_sample_out, p_mtimpl_.get(), p_fmat, param,
|
||||
h_out_position, *tree_it);
|
||||
} else {
|
||||
UpdateTree<CPUExpandEntry>(&monitor_, h_sample_out, p_impl_.get(), p_fmat, param,
|
||||
h_out_position, *tree_it);
|
||||
@@ -372,6 +585,9 @@ class QuantileHistMaker : public TreeUpdater {
|
||||
bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) override {
|
||||
if (p_impl_) {
|
||||
return p_impl_->UpdatePredictionCache(data, out_preds);
|
||||
} else if (p_mtimpl_) {
|
||||
// Not yet supported.
|
||||
return false;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
@@ -383,6 +599,6 @@ class QuantileHistMaker : public TreeUpdater {
|
||||
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
|
||||
.describe("Grow tree using quantized histogram.")
|
||||
.set_body([](Context const *ctx, ObjInfo const *task) {
|
||||
return new QuantileHistMaker(ctx, task);
|
||||
return new QuantileHistMaker{ctx, task};
|
||||
});
|
||||
} // namespace xgboost::tree
|
||||
|
||||
Reference in New Issue
Block a user