More collective aggregators (#9060)
This commit is contained in:
parent
7032981350
commit
8dbe0510de
@ -8,6 +8,7 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
#include <xgboost/data.h>
|
#include <xgboost/data.h>
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -57,5 +58,72 @@ void ApplyWithLabels(MetaInfo const& info, T* buffer, size_t size, Function&& fu
|
|||||||
std::forward<Function>(function)(std::forward<Args>(args)...);
|
std::forward<Function>(function)(std::forward<Args>(args)...);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Find the global max of the given value across all workers.
|
||||||
|
*
|
||||||
|
* This only applies when the data is split row-wise (horizontally). When data is split
|
||||||
|
* column-wise (vertically), the local value is returned.
|
||||||
|
*
|
||||||
|
* @tparam T The type of the value.
|
||||||
|
* @param info MetaInfo about the DMatrix.
|
||||||
|
* @param value The input for finding the global max.
|
||||||
|
* @return The global max of the input.
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
T GlobalMax(MetaInfo const& info, T value) {
|
||||||
|
if (info.IsRowSplit()) {
|
||||||
|
collective::Allreduce<collective::Operation::kMax>(&value, 1);
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Find the global sum of the given values across all workers.
|
||||||
|
*
|
||||||
|
* This only applies when the data is split row-wise (horizontally). When data is split
|
||||||
|
* column-wise (vertically), the original values are returned.
|
||||||
|
*
|
||||||
|
* @tparam T The type of the values.
|
||||||
|
* @param info MetaInfo about the DMatrix.
|
||||||
|
* @param values Pointer to the inputs to sum.
|
||||||
|
* @param size Number of values to sum.
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
void GlobalSum(MetaInfo const& info, T* values, size_t size) {
|
||||||
|
if (info.IsRowSplit()) {
|
||||||
|
collective::Allreduce<collective::Operation::kSum>(values, size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Container>
|
||||||
|
void GlobalSum(MetaInfo const& info, Container* values) {
|
||||||
|
GlobalSum(info, values->data(), values->size());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Find the global ratio of the given two values across all workers.
|
||||||
|
*
|
||||||
|
* This only applies when the data is split row-wise (horizontally). When data is split
|
||||||
|
* column-wise (vertically), the local ratio is returned.
|
||||||
|
*
|
||||||
|
* @tparam T The type of the values.
|
||||||
|
* @param info MetaInfo about the DMatrix.
|
||||||
|
* @param dividend The dividend of the ratio.
|
||||||
|
* @param divisor The divisor of the ratio.
|
||||||
|
* @return The global ratio of the two inputs.
|
||||||
|
*/
|
||||||
|
template <typename T>
|
||||||
|
T GlobalRatio(MetaInfo const& info, T dividend, T divisor) {
|
||||||
|
std::array<T, 2> results{dividend, divisor};
|
||||||
|
GlobalSum(info, &results);
|
||||||
|
std::tie(dividend, divisor) = std::tuple_cat(results);
|
||||||
|
if (divisor <= 0) {
|
||||||
|
return std::numeric_limits<T>::quiet_NaN();
|
||||||
|
} else {
|
||||||
|
return dividend / divisor;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace collective
|
} // namespace collective
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -116,10 +116,7 @@ double MultiClassOVR(Context const *ctx, common::Span<float const> predts, MetaI
|
|||||||
|
|
||||||
// we have 2 averages going in here, first is among workers, second is among
|
// we have 2 averages going in here, first is among workers, second is among
|
||||||
// classes. allreduce sums up fp/tp auc for each class.
|
// classes. allreduce sums up fp/tp auc for each class.
|
||||||
if (info.IsRowSplit()) {
|
collective::GlobalSum(info, &results.Values());
|
||||||
collective::Allreduce<collective::Operation::kSum>(results.Values().data(),
|
|
||||||
results.Values().size());
|
|
||||||
}
|
|
||||||
double auc_sum{0};
|
double auc_sum{0};
|
||||||
double tp_sum{0};
|
double tp_sum{0};
|
||||||
for (size_t c = 0; c < n_classes; ++c) {
|
for (size_t c = 0; c < n_classes; ++c) {
|
||||||
@ -293,17 +290,8 @@ class EvalAUC : public MetricNoCache {
|
|||||||
InvalidGroupAUC();
|
InvalidGroupAUC();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::array<double, 2> results{auc, static_cast<double>(valid_groups)};
|
auc = collective::GlobalRatio(info, auc, static_cast<double>(valid_groups));
|
||||||
if (info.IsRowSplit()) {
|
if (!std::isnan(auc)) {
|
||||||
collective::Allreduce<collective::Operation::kSum>(results.data(), results.size());
|
|
||||||
}
|
|
||||||
auc = results[0];
|
|
||||||
valid_groups = static_cast<uint32_t>(results[1]);
|
|
||||||
|
|
||||||
if (valid_groups <= 0) {
|
|
||||||
auc = std::numeric_limits<double>::quiet_NaN();
|
|
||||||
} else {
|
|
||||||
auc /= valid_groups;
|
|
||||||
CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups
|
CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups
|
||||||
<< ", valid groups: " << valid_groups;
|
<< ", valid groups: " << valid_groups;
|
||||||
}
|
}
|
||||||
@ -323,19 +311,9 @@ class EvalAUC : public MetricNoCache {
|
|||||||
std::tie(fp, tp, auc) =
|
std::tie(fp, tp, auc) =
|
||||||
static_cast<Curve *>(this)->EvalBinary(preds, info);
|
static_cast<Curve *>(this)->EvalBinary(preds, info);
|
||||||
}
|
}
|
||||||
double local_area = fp * tp;
|
auc = collective::GlobalRatio(info, auc, fp * tp);
|
||||||
std::array<double, 2> result{auc, local_area};
|
if (!std::isnan(auc)) {
|
||||||
if (info.IsRowSplit()) {
|
CHECK_LE(auc, 1.0);
|
||||||
collective::Allreduce<collective::Operation::kSum>(result.data(), result.size());
|
|
||||||
}
|
|
||||||
std::tie(auc, local_area) = common::UnpackArr(std::move(result));
|
|
||||||
if (local_area <= 0) {
|
|
||||||
// the dataset across all workers have only positive or negative sample
|
|
||||||
auc = std::numeric_limits<double>::quiet_NaN();
|
|
||||||
} else {
|
|
||||||
CHECK_LE(auc, local_area);
|
|
||||||
// normalization
|
|
||||||
auc = auc / local_area;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (std::isnan(auc)) {
|
if (std::isnan(auc)) {
|
||||||
|
|||||||
@ -8,6 +8,7 @@
|
|||||||
*/
|
*/
|
||||||
#include <dmlc/registry.h>
|
#include <dmlc/registry.h>
|
||||||
|
|
||||||
|
#include <array>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
#include "../collective/communicator-inl.h"
|
#include "../collective/communicator-inl.h"
|
||||||
@ -197,10 +198,8 @@ class PseudoErrorLoss : public MetricNoCache {
|
|||||||
auto v = common::Sqr(slope) * (std::sqrt((1 + common::Sqr(a / slope))) - 1) * wt;
|
auto v = common::Sqr(slope) * (std::sqrt((1 + common::Sqr(a / slope))) - 1) * wt;
|
||||||
return std::make_tuple(v, wt);
|
return std::make_tuple(v, wt);
|
||||||
});
|
});
|
||||||
double dat[2]{result.Residue(), result.Weights()};
|
std::array<double, 2> dat{result.Residue(), result.Weights()};
|
||||||
if (info.IsRowSplit()) {
|
collective::GlobalSum(info, &dat);
|
||||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
|
||||||
}
|
|
||||||
return EvalRowMAPE::GetFinal(dat[0], dat[1]);
|
return EvalRowMAPE::GetFinal(dat[0], dat[1]);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -366,10 +365,8 @@ struct EvalEWiseBase : public MetricNoCache {
|
|||||||
return std::make_tuple(residue, wt);
|
return std::make_tuple(residue, wt);
|
||||||
});
|
});
|
||||||
|
|
||||||
double dat[2]{result.Residue(), result.Weights()};
|
std::array<double, 2> dat{result.Residue(), result.Weights()};
|
||||||
if (info.IsRowSplit()) {
|
collective::GlobalSum(info, &dat);
|
||||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
|
||||||
}
|
|
||||||
return Policy::GetFinal(dat[0], dat[1]);
|
return Policy::GetFinal(dat[0], dat[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -440,10 +437,8 @@ class QuantileError : public MetricNoCache {
|
|||||||
CHECK(!alpha_.Empty());
|
CHECK(!alpha_.Empty());
|
||||||
if (info.num_row_ == 0) {
|
if (info.num_row_ == 0) {
|
||||||
// empty DMatrix on distributed env
|
// empty DMatrix on distributed env
|
||||||
double dat[2]{0.0, 0.0};
|
std::array<double, 2> dat{0.0, 0.0};
|
||||||
if (info.IsRowSplit()) {
|
collective::GlobalSum(info, &dat);
|
||||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
|
||||||
}
|
|
||||||
CHECK_GT(dat[1], 0);
|
CHECK_GT(dat[1], 0);
|
||||||
return dat[0] / dat[1];
|
return dat[0] / dat[1];
|
||||||
}
|
}
|
||||||
@ -480,10 +475,8 @@ class QuantileError : public MetricNoCache {
|
|||||||
loss(y_predt(sample_id, quantile_id, target_id), y_true(sample_id, target_id)) * w;
|
loss(y_predt(sample_id, quantile_id, target_id), y_true(sample_id, target_id)) * w;
|
||||||
return std::make_tuple(l, w);
|
return std::make_tuple(l, w);
|
||||||
});
|
});
|
||||||
double dat[2]{result.Residue(), result.Weights()};
|
std::array<double, 2> dat{result.Residue(), result.Weights()};
|
||||||
if (info.IsRowSplit()) {
|
collective::GlobalSum(info, &dat);
|
||||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
|
||||||
}
|
|
||||||
CHECK_GT(dat[1], 0);
|
CHECK_GT(dat[1], 0);
|
||||||
return dat[0] / dat[1];
|
return dat[0] / dat[1];
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,6 +6,7 @@
|
|||||||
*/
|
*/
|
||||||
#include <xgboost/metric.h>
|
#include <xgboost/metric.h>
|
||||||
|
|
||||||
|
#include <array>
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
@ -169,7 +170,7 @@ struct EvalMClassBase : public MetricNoCache {
|
|||||||
} else {
|
} else {
|
||||||
CHECK(preds.Size() % info.labels.Size() == 0) << "label and prediction size not match";
|
CHECK(preds.Size() % info.labels.Size() == 0) << "label and prediction size not match";
|
||||||
}
|
}
|
||||||
double dat[2] { 0.0, 0.0 };
|
std::array<double, 2> dat{0.0, 0.0};
|
||||||
if (info.labels.Size() != 0) {
|
if (info.labels.Size() != 0) {
|
||||||
const size_t nclass = preds.Size() / info.labels.Size();
|
const size_t nclass = preds.Size() / info.labels.Size();
|
||||||
CHECK_GE(nclass, 1U)
|
CHECK_GE(nclass, 1U)
|
||||||
@ -181,9 +182,7 @@ struct EvalMClassBase : public MetricNoCache {
|
|||||||
dat[0] = result.Residue();
|
dat[0] = result.Residue();
|
||||||
dat[1] = result.Weights();
|
dat[1] = result.Weights();
|
||||||
}
|
}
|
||||||
if (info.IsRowSplit()) {
|
collective::GlobalSum(info, &dat);
|
||||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
|
||||||
}
|
|
||||||
return Derived::GetFinal(dat[0], dat[1]);
|
return Derived::GetFinal(dat[0], dat[1]);
|
||||||
}
|
}
|
||||||
/*!
|
/*!
|
||||||
|
|||||||
@ -238,14 +238,7 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig {
|
|||||||
exc.Rethrow();
|
exc.Rethrow();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (collective::IsDistributed() && info.IsRowSplit()) {
|
return collective::GlobalRatio(info, sum_metric, static_cast<double>(ngroups));
|
||||||
double dat[2]{sum_metric, static_cast<double>(ngroups)};
|
|
||||||
// approximately estimate the metric using mean
|
|
||||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
|
||||||
return dat[0] / dat[1];
|
|
||||||
} else {
|
|
||||||
return sum_metric / ngroups;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const char* Name() const override {
|
const char* Name() const override {
|
||||||
@ -401,9 +394,8 @@ class EvalRankWithCache : public Metric {
|
|||||||
namespace {
|
namespace {
|
||||||
double Finalize(MetaInfo const& info, double score, double sw) {
|
double Finalize(MetaInfo const& info, double score, double sw) {
|
||||||
std::array<double, 2> dat{score, sw};
|
std::array<double, 2> dat{score, sw};
|
||||||
if (info.IsRowSplit()) {
|
collective::GlobalSum(info, &dat);
|
||||||
collective::Allreduce<collective::Operation::kSum>(dat.data(), dat.size());
|
std::tie(score, sw) = std::tuple_cat(dat);
|
||||||
}
|
|
||||||
if (sw > 0.0) {
|
if (sw > 0.0) {
|
||||||
score = score / sw;
|
score = score / sw;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -7,6 +7,7 @@
|
|||||||
|
|
||||||
#include <dmlc/registry.h>
|
#include <dmlc/registry.h>
|
||||||
|
|
||||||
|
#include <array>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -211,10 +212,8 @@ struct EvalEWiseSurvivalBase : public MetricNoCache {
|
|||||||
auto result = reducer_.Reduce(*ctx_, info.weights_, info.labels_lower_bound_,
|
auto result = reducer_.Reduce(*ctx_, info.weights_, info.labels_lower_bound_,
|
||||||
info.labels_upper_bound_, preds);
|
info.labels_upper_bound_, preds);
|
||||||
|
|
||||||
double dat[2]{result.Residue(), result.Weights()};
|
std::array<double, 2> dat{result.Residue(), result.Weights()};
|
||||||
if (info.IsRowSplit()) {
|
collective::GlobalSum(info, &dat);
|
||||||
collective::Allreduce<collective::Operation::kSum>(dat, 2);
|
|
||||||
}
|
|
||||||
return Policy::GetFinal(dat[0], dat[1]);
|
return Policy::GetFinal(dat[0], dat[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -6,8 +6,9 @@
|
|||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cstdint> // std::int32_t
|
#include <cstdint> // std::int32_t
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <vector> // std::vector
|
#include <vector> // std::vector
|
||||||
|
|
||||||
|
#include "../collective/aggregator.h"
|
||||||
#include "../collective/communicator-inl.h"
|
#include "../collective/communicator-inl.h"
|
||||||
#include "../common/common.h"
|
#include "../common/common.h"
|
||||||
#include "xgboost/base.h" // bst_node_t
|
#include "xgboost/base.h" // bst_node_t
|
||||||
@ -41,10 +42,7 @@ inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_no
|
|||||||
auto& quantiles = *p_quantiles;
|
auto& quantiles = *p_quantiles;
|
||||||
auto const& h_node_idx = nidx;
|
auto const& h_node_idx = nidx;
|
||||||
|
|
||||||
size_t n_leaf{h_node_idx.size()};
|
size_t n_leaf = collective::GlobalMax(info, h_node_idx.size());
|
||||||
if (info.IsRowSplit()) {
|
|
||||||
collective::Allreduce<collective::Operation::kMax>(&n_leaf, 1);
|
|
||||||
}
|
|
||||||
CHECK(quantiles.empty() || quantiles.size() == n_leaf);
|
CHECK(quantiles.empty() || quantiles.size() == n_leaf);
|
||||||
if (quantiles.empty()) {
|
if (quantiles.empty()) {
|
||||||
quantiles.resize(n_leaf, std::numeric_limits<float>::quiet_NaN());
|
quantiles.resize(n_leaf, std::numeric_limits<float>::quiet_NaN());
|
||||||
@ -54,16 +52,12 @@ inline void UpdateLeafValues(std::vector<float>* p_quantiles, std::vector<bst_no
|
|||||||
std::vector<int32_t> n_valids(quantiles.size());
|
std::vector<int32_t> n_valids(quantiles.size());
|
||||||
std::transform(quantiles.cbegin(), quantiles.cend(), n_valids.begin(),
|
std::transform(quantiles.cbegin(), quantiles.cend(), n_valids.begin(),
|
||||||
[](float q) { return static_cast<int32_t>(!std::isnan(q)); });
|
[](float q) { return static_cast<int32_t>(!std::isnan(q)); });
|
||||||
if (info.IsRowSplit()) {
|
collective::GlobalSum(info, &n_valids);
|
||||||
collective::Allreduce<collective::Operation::kSum>(n_valids.data(), n_valids.size());
|
|
||||||
}
|
|
||||||
// convert to 0 for all reduce
|
// convert to 0 for all reduce
|
||||||
std::replace_if(
|
std::replace_if(
|
||||||
quantiles.begin(), quantiles.end(), [](float q) { return std::isnan(q); }, 0.f);
|
quantiles.begin(), quantiles.end(), [](float q) { return std::isnan(q); }, 0.f);
|
||||||
// use the mean value
|
// use the mean value
|
||||||
if (info.IsRowSplit()) {
|
collective::GlobalSum(info, &quantiles);
|
||||||
collective::Allreduce<collective::Operation::kSum>(quantiles.data(), quantiles.size());
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i < n_leaf; ++i) {
|
for (size_t i = 0; i < n_leaf; ++i) {
|
||||||
if (n_valids[i] > 0) {
|
if (n_valids[i] > 0) {
|
||||||
quantiles[i] /= static_cast<float>(n_valids[i]);
|
quantiles[i] /= static_cast<float>(n_valids[i]);
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
/**
|
/**
|
||||||
* Copyright 2023 by XGBoost contributors
|
* Copyright 2023 by XGBoost contributors
|
||||||
*/
|
*/
|
||||||
|
#include <array> // std::array
|
||||||
#include <cstddef> // std::size_t
|
#include <cstddef> // std::size_t
|
||||||
#include <cstdint> // std::int32_t
|
#include <cstdint> // std::int32_t
|
||||||
#include <vector> // std::vector
|
#include <vector> // std::vector
|
||||||
@ -170,10 +171,9 @@ class QuantileRegression : public ObjFunction {
|
|||||||
common::Mean(ctx_, *base_score, &temp);
|
common::Mean(ctx_, *base_score, &temp);
|
||||||
double meanq = temp(0) * sw;
|
double meanq = temp(0) * sw;
|
||||||
|
|
||||||
if (info.IsRowSplit()) {
|
std::array<double, 2> dat{meanq, sw};
|
||||||
collective::Allreduce<collective::Operation::kSum>(&meanq, 1);
|
collective::GlobalSum(info, &dat);
|
||||||
collective::Allreduce<collective::Operation::kSum>(&sw, 1);
|
std::tie(meanq, sw) = std::tuple_cat(dat);
|
||||||
}
|
|
||||||
meanq /= (sw + kRtEps);
|
meanq /= (sw + kRtEps);
|
||||||
base_score->Reshape(1);
|
base_score->Reshape(1);
|
||||||
base_score->Data()->Fill(meanq);
|
base_score->Data()->Fill(meanq);
|
||||||
|
|||||||
@ -728,10 +728,8 @@ class MeanAbsoluteError : public ObjFunction {
|
|||||||
std::transform(linalg::cbegin(out), linalg::cend(out), linalg::begin(out),
|
std::transform(linalg::cbegin(out), linalg::cend(out), linalg::begin(out),
|
||||||
[w](float v) { return v * w; });
|
[w](float v) { return v * w; });
|
||||||
|
|
||||||
if (info.IsRowSplit()) {
|
collective::GlobalSum(info, &out.Values());
|
||||||
collective::Allreduce<collective::Operation::kSum>(out.Values().data(), out.Values().size());
|
collective::GlobalSum(info, &w, 1);
|
||||||
collective::Allreduce<collective::Operation::kSum>(&w, 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (common::CloseTo(w, 0.0)) {
|
if (common::CloseTo(w, 0.0)) {
|
||||||
// Mostly for handling empty dataset test.
|
// Mostly for handling empty dataset test.
|
||||||
|
|||||||
@ -8,6 +8,7 @@
|
|||||||
#include <cinttypes> // std::int32_t
|
#include <cinttypes> // std::int32_t
|
||||||
#include <cstddef> // std::size_t
|
#include <cstddef> // std::size_t
|
||||||
|
|
||||||
|
#include "../collective/aggregator.h"
|
||||||
#include "../collective/communicator-inl.h"
|
#include "../collective/communicator-inl.h"
|
||||||
#include "../common/common.h" // AssertGPUSupport
|
#include "../common/common.h" // AssertGPUSupport
|
||||||
#include "../common/numeric.h" // cpu_impl::Reduce
|
#include "../common/numeric.h" // cpu_impl::Reduce
|
||||||
@ -45,10 +46,7 @@ void FitStump(Context const* ctx, MetaInfo const& info,
|
|||||||
}
|
}
|
||||||
CHECK(h_sum.CContiguous());
|
CHECK(h_sum.CContiguous());
|
||||||
|
|
||||||
if (info.IsRowSplit()) {
|
collective::GlobalSum(info, reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
|
||||||
collective::Allreduce<collective::Operation::kSum>(
|
|
||||||
reinterpret_cast<double*>(h_sum.Values().data()), h_sum.Size() * 2);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (std::size_t i = 0; i < h_sum.Size(); ++i) {
|
for (std::size_t i = 0; i < h_sum.Size(); ++i) {
|
||||||
out(i) = static_cast<float>(CalcUnregularizedWeight(h_sum(i).GetGrad(), h_sum(i).GetHess()));
|
out(i) = static_cast<float>(CalcUnregularizedWeight(h_sum(i).GetGrad(), h_sum(i).GetHess()));
|
||||||
|
|||||||
@ -7,6 +7,7 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "../collective/aggregator.h"
|
||||||
#include "../common/random.h"
|
#include "../common/random.h"
|
||||||
#include "../data/gradient_index.h"
|
#include "../data/gradient_index.h"
|
||||||
#include "common_row_partitioner.h"
|
#include "common_row_partitioner.h"
|
||||||
@ -92,9 +93,7 @@ class GloablApproxBuilder {
|
|||||||
for (auto const &g : gpair) {
|
for (auto const &g : gpair) {
|
||||||
root_sum.Add(g);
|
root_sum.Add(g);
|
||||||
}
|
}
|
||||||
if (p_fmat->Info().IsRowSplit()) {
|
collective::GlobalSum(p_fmat->Info(), reinterpret_cast<double *>(&root_sum), 2);
|
||||||
collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<double *>(&root_sum), 2);
|
|
||||||
}
|
|
||||||
std::vector<CPUExpandEntry> nodes{best};
|
std::vector<CPUExpandEntry> nodes{best};
|
||||||
size_t i = 0;
|
size_t i = 0;
|
||||||
auto space = ConstructHistSpace(partitioner_, nodes);
|
auto space = ConstructHistSpace(partitioner_, nodes);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user