Avoid regenerating the gradient index for approx. (#7591)
This commit is contained in:
@@ -351,6 +351,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
|
||||
|
||||
auto Evaluator() const { return tree_evaluator_.GetEvaluator(); }
|
||||
auto const& Stats() const { return snode_; }
|
||||
auto Task() const { return task_; }
|
||||
|
||||
float InitRoot(GradStats const& root_sum) {
|
||||
snode_.resize(1);
|
||||
|
||||
@@ -26,6 +26,19 @@ namespace tree {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(updater_approx);
|
||||
|
||||
namespace {
|
||||
// Return the BatchParam used by DMatrix.
|
||||
template <typename GradientSumT>
|
||||
auto BatchSpec(TrainParam const &p, common::Span<float> hess,
|
||||
HistEvaluator<GradientSumT, CPUExpandEntry> const &evaluator) {
|
||||
return BatchParam{GenericParameter::kCpuId, p.max_bin, hess, !evaluator.Task().const_hess};
|
||||
}
|
||||
|
||||
auto BatchSpec(TrainParam const &p, common::Span<float> hess) {
|
||||
return BatchParam{GenericParameter::kCpuId, p.max_bin, hess, false};
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
template <typename GradientSumT>
|
||||
class GloablApproxBuilder {
|
||||
protected:
|
||||
@@ -46,12 +59,13 @@ class GloablApproxBuilder {
|
||||
public:
|
||||
void InitData(DMatrix *p_fmat, common::Span<float> hess) {
|
||||
monitor_->Start(__func__);
|
||||
|
||||
n_batches_ = 0;
|
||||
int32_t n_total_bins = 0;
|
||||
partitioner_.clear();
|
||||
// Generating the GHistIndexMatrix is quite slow, is there a way to speed it up?
|
||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(
|
||||
{GenericParameter::kCpuId, param_.max_bin, hess, true})) {
|
||||
for (auto const &page :
|
||||
p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess, evaluator_))) {
|
||||
if (n_total_bins == 0) {
|
||||
n_total_bins = page.cut.TotalBins();
|
||||
feature_values_ = page.cut;
|
||||
@@ -62,9 +76,8 @@ class GloablApproxBuilder {
|
||||
n_batches_++;
|
||||
}
|
||||
|
||||
histogram_builder_.Reset(n_total_bins,
|
||||
BatchParam{GenericParameter::kCpuId, param_.max_bin, hess},
|
||||
ctx_->Threads(), n_batches_, rabit::IsDistributed());
|
||||
histogram_builder_.Reset(n_total_bins, BatchSpec(param_, hess), ctx_->Threads(), n_batches_,
|
||||
rabit::IsDistributed());
|
||||
monitor_->Stop(__func__);
|
||||
}
|
||||
|
||||
@@ -82,8 +95,7 @@ class GloablApproxBuilder {
|
||||
std::vector<CPUExpandEntry> nodes{best};
|
||||
size_t i = 0;
|
||||
auto space = this->ConstructHistSpace(nodes);
|
||||
for (auto const &page :
|
||||
p_fmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, param_.max_bin, hess})) {
|
||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess))) {
|
||||
histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes,
|
||||
{}, gpair);
|
||||
i++;
|
||||
@@ -175,8 +187,7 @@ class GloablApproxBuilder {
|
||||
|
||||
size_t i = 0;
|
||||
auto space = this->ConstructHistSpace(nodes_to_build);
|
||||
for (auto const &page :
|
||||
p_fmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, param_.max_bin, hess})) {
|
||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess))) {
|
||||
histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
|
||||
nodes_to_build, nodes_to_sub, gpair);
|
||||
i++;
|
||||
@@ -225,8 +236,7 @@ class GloablApproxBuilder {
|
||||
|
||||
monitor_->Start("UpdatePosition");
|
||||
size_t i = 0;
|
||||
for (auto const &page :
|
||||
p_fmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, param_.max_bin, hess})) {
|
||||
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess))) {
|
||||
partitioner_.at(i).UpdatePosition(ctx_, page, applied, p_tree);
|
||||
i++;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user