Avoid omp reduction in coordinate descent and aft metrics. (#7316)
Aside from the omp issue, parameter configuration for aft metric is simplified.
This commit is contained in:
parent
f56e2e9a66
commit
fb1a9e6bc5
@ -108,27 +108,32 @@ inline std::pair<double, double> GetGradient(int group_idx, int num_group, int f
|
|||||||
*
|
*
|
||||||
* \return The gradient and diagonal Hessian entry for a given feature.
|
* \return The gradient and diagonal Hessian entry for a given feature.
|
||||||
*/
|
*/
|
||||||
inline std::pair<double, double> GetGradientParallel(int group_idx, int num_group, int fidx,
|
inline std::pair<double, double>
|
||||||
const std::vector<GradientPair> &gpair,
|
GetGradientParallel(GenericParameter const *ctx, int group_idx, int num_group,
|
||||||
|
int fidx, const std::vector<GradientPair> &gpair,
|
||||||
DMatrix *p_fmat) {
|
DMatrix *p_fmat) {
|
||||||
double sum_grad = 0.0, sum_hess = 0.0;
|
std::vector<double> sum_grad_tloc(ctx->Threads(), 0.0);
|
||||||
|
std::vector<double> sum_hess_tloc(ctx->Threads(), 0.0);
|
||||||
|
|
||||||
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
|
||||||
auto page = batch.GetView();
|
auto page = batch.GetView();
|
||||||
auto col = page[fidx];
|
auto col = page[fidx];
|
||||||
const auto ndata = static_cast<bst_omp_uint>(col.size());
|
const auto ndata = static_cast<bst_omp_uint>(col.size());
|
||||||
dmlc::OMPException exc;
|
common::ParallelFor(ndata, ctx->Threads(), [&](size_t j) {
|
||||||
#pragma omp parallel for schedule(static) reduction(+ : sum_grad, sum_hess)
|
|
||||||
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
|
||||||
exc.Run([&]() {
|
|
||||||
const bst_float v = col[j].fvalue;
|
const bst_float v = col[j].fvalue;
|
||||||
auto &p = gpair[col[j].index * num_group + group_idx];
|
auto &p = gpair[col[j].index * num_group + group_idx];
|
||||||
if (p.GetHess() < 0.0f) return;
|
if (p.GetHess() < 0.0f) {
|
||||||
sum_grad += p.GetGrad() * v;
|
return;
|
||||||
sum_hess += p.GetHess() * v * v;
|
}
|
||||||
|
auto t_idx = omp_get_thread_num();
|
||||||
|
sum_grad_tloc[t_idx] += p.GetGrad() * v;
|
||||||
|
sum_hess_tloc[t_idx] += p.GetHess() * v * v;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
exc.Rethrow();
|
double sum_grad =
|
||||||
}
|
std::accumulate(sum_grad_tloc.cbegin(), sum_grad_tloc.cend(), 0.0);
|
||||||
|
double sum_hess =
|
||||||
|
std::accumulate(sum_hess_tloc.cbegin(), sum_hess_tloc.cend(), 0.0);
|
||||||
return std::make_pair(sum_grad, sum_hess);
|
return std::make_pair(sum_grad, sum_hess);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -80,8 +80,8 @@ class CoordinateUpdater : public LinearUpdater {
|
|||||||
DMatrix *p_fmat, gbm::GBLinearModel *model) {
|
DMatrix *p_fmat, gbm::GBLinearModel *model) {
|
||||||
const int ngroup = model->learner_model_param->num_output_group;
|
const int ngroup = model->learner_model_param->num_output_group;
|
||||||
bst_float &w = (*model)[fidx][group_idx];
|
bst_float &w = (*model)[fidx][group_idx];
|
||||||
auto gradient =
|
auto gradient = GetGradientParallel(learner_param_, group_idx, ngroup, fidx,
|
||||||
GetGradientParallel(group_idx, ngroup, fidx, *in_gpair, p_fmat);
|
*in_gpair, p_fmat);
|
||||||
auto dw = static_cast<float>(
|
auto dw = static_cast<float>(
|
||||||
tparam_.learning_rate *
|
tparam_.learning_rate *
|
||||||
CoordinateDelta(gradient.first, gradient.second, w, tparam_.reg_alpha_denorm,
|
CoordinateDelta(gradient.first, gradient.second, w, tparam_.reg_alpha_denorm,
|
||||||
|
|||||||
@ -18,6 +18,7 @@
|
|||||||
#include "metric_common.h"
|
#include "metric_common.h"
|
||||||
#include "../common/math.h"
|
#include "../common/math.h"
|
||||||
#include "../common/survival_util.h"
|
#include "../common/survival_util.h"
|
||||||
|
#include "../common/threading_utils.h"
|
||||||
|
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
#include <thrust/execution_policy.h> // thrust::cuda::par
|
#include <thrust/execution_policy.h> // thrust::cuda::par
|
||||||
@ -42,11 +43,12 @@ class ElementWiseSurvivalMetricsReduction {
|
|||||||
policy_ = policy;
|
policy_ = policy;
|
||||||
}
|
}
|
||||||
|
|
||||||
PackedReduceResult CpuReduceMetrics(
|
PackedReduceResult
|
||||||
const HostDeviceVector<bst_float>& weights,
|
CpuReduceMetrics(const HostDeviceVector<bst_float> &weights,
|
||||||
const HostDeviceVector<bst_float>& labels_lower_bound,
|
const HostDeviceVector<bst_float> &labels_lower_bound,
|
||||||
const HostDeviceVector<bst_float>& labels_upper_bound,
|
const HostDeviceVector<bst_float> &labels_upper_bound,
|
||||||
const HostDeviceVector<bst_float>& preds) const {
|
const HostDeviceVector<bst_float> &preds,
|
||||||
|
int32_t n_threads) const {
|
||||||
size_t ndata = labels_lower_bound.Size();
|
size_t ndata = labels_lower_bound.Size();
|
||||||
CHECK_EQ(ndata, labels_upper_bound.Size());
|
CHECK_EQ(ndata, labels_upper_bound.Size());
|
||||||
|
|
||||||
@ -55,22 +57,24 @@ class ElementWiseSurvivalMetricsReduction {
|
|||||||
const auto& h_weights = weights.HostVector();
|
const auto& h_weights = weights.HostVector();
|
||||||
const auto& h_preds = preds.HostVector();
|
const auto& h_preds = preds.HostVector();
|
||||||
|
|
||||||
double residue_sum = 0;
|
std::vector<double> score_tloc(n_threads, 0.0);
|
||||||
double weights_sum = 0;
|
std::vector<double> weight_tloc(n_threads, 0.0);
|
||||||
|
|
||||||
dmlc::OMPException exc;
|
common::ParallelFor(ndata, n_threads, [&](size_t i) {
|
||||||
#pragma omp parallel for reduction(+: residue_sum, weights_sum) schedule(static)
|
const double wt =
|
||||||
for (omp_ulong i = 0; i < ndata; ++i) {
|
h_weights.empty() ? 1.0 : static_cast<double>(h_weights[i]);
|
||||||
exc.Run([&]() {
|
auto t_idx = omp_get_thread_num();
|
||||||
const double wt = h_weights.empty() ? 1.0 : static_cast<double>(h_weights[i]);
|
score_tloc[t_idx] +=
|
||||||
residue_sum += policy_.EvalRow(
|
policy_.EvalRow(static_cast<double>(h_labels_lower_bound[i]),
|
||||||
static_cast<double>(h_labels_lower_bound[i]),
|
|
||||||
static_cast<double>(h_labels_upper_bound[i]),
|
static_cast<double>(h_labels_upper_bound[i]),
|
||||||
static_cast<double>(h_preds[i])) * wt;
|
static_cast<double>(h_preds[i])) *
|
||||||
weights_sum += wt;
|
wt;
|
||||||
|
weight_tloc[t_idx] += wt;
|
||||||
});
|
});
|
||||||
}
|
|
||||||
exc.Rethrow();
|
double residue_sum = std::accumulate(score_tloc.cbegin(), score_tloc.cend(), 0.0);
|
||||||
|
double weights_sum = std::accumulate(weight_tloc.cbegin(), weight_tloc.cend(), 0.0);
|
||||||
|
|
||||||
PackedReduceResult res{residue_sum, weights_sum};
|
PackedReduceResult res{residue_sum, weights_sum};
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
@ -119,25 +123,25 @@ class ElementWiseSurvivalMetricsReduction {
|
|||||||
#endif // XGBOOST_USE_CUDA
|
#endif // XGBOOST_USE_CUDA
|
||||||
|
|
||||||
PackedReduceResult Reduce(
|
PackedReduceResult Reduce(
|
||||||
int device,
|
const GenericParameter &ctx,
|
||||||
const HostDeviceVector<bst_float>& weights,
|
const HostDeviceVector<bst_float>& weights,
|
||||||
const HostDeviceVector<bst_float>& labels_lower_bound,
|
const HostDeviceVector<bst_float>& labels_lower_bound,
|
||||||
const HostDeviceVector<bst_float>& labels_upper_bound,
|
const HostDeviceVector<bst_float>& labels_upper_bound,
|
||||||
const HostDeviceVector<bst_float>& preds) {
|
const HostDeviceVector<bst_float>& preds) {
|
||||||
PackedReduceResult result;
|
PackedReduceResult result;
|
||||||
|
|
||||||
if (device < 0) {
|
if (ctx.gpu_id < 0) {
|
||||||
result = CpuReduceMetrics(weights, labels_lower_bound, labels_upper_bound, preds);
|
result = CpuReduceMetrics(weights, labels_lower_bound, labels_upper_bound,
|
||||||
|
preds, ctx.Threads());
|
||||||
}
|
}
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
else { // NOLINT
|
else { // NOLINT
|
||||||
device_ = device;
|
preds.SetDevice(ctx.gpu_id);
|
||||||
preds.SetDevice(device_);
|
labels_lower_bound.SetDevice(ctx.gpu_id);
|
||||||
labels_lower_bound.SetDevice(device_);
|
labels_upper_bound.SetDevice(ctx.gpu_id);
|
||||||
labels_upper_bound.SetDevice(device_);
|
weights.SetDevice(ctx.gpu_id);
|
||||||
weights.SetDevice(device_);
|
|
||||||
|
|
||||||
dh::safe_cuda(cudaSetDevice(device_));
|
dh::safe_cuda(cudaSetDevice(ctx.gpu_id));
|
||||||
result = DeviceReduceMetrics(weights, labels_lower_bound, labels_upper_bound, preds);
|
result = DeviceReduceMetrics(weights, labels_lower_bound, labels_upper_bound, preds);
|
||||||
}
|
}
|
||||||
#endif // defined(XGBOOST_USE_CUDA)
|
#endif // defined(XGBOOST_USE_CUDA)
|
||||||
@ -146,9 +150,6 @@ class ElementWiseSurvivalMetricsReduction {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
EvalRow policy_;
|
EvalRow policy_;
|
||||||
#if defined(XGBOOST_USE_CUDA)
|
|
||||||
int device_{-1};
|
|
||||||
#endif // defined(XGBOOST_USE_CUDA)
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct EvalIntervalRegressionAccuracy {
|
struct EvalIntervalRegressionAccuracy {
|
||||||
@ -193,18 +194,16 @@ struct EvalAFTNLogLik {
|
|||||||
AFTParam param_;
|
AFTParam param_;
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Policy>
|
template <typename Policy> struct EvalEWiseSurvivalBase : public Metric {
|
||||||
struct EvalEWiseSurvivalBase : public Metric {
|
explicit EvalEWiseSurvivalBase(GenericParameter const *ctx) {
|
||||||
|
tparam_ = ctx;
|
||||||
|
}
|
||||||
EvalEWiseSurvivalBase() = default;
|
EvalEWiseSurvivalBase() = default;
|
||||||
|
|
||||||
void Configure(const Args& args) override {
|
void Configure(const Args& args) override {
|
||||||
policy_.Configure(args);
|
policy_.Configure(args);
|
||||||
for (const auto& e : args) {
|
|
||||||
if (e.first == "gpu_id") {
|
|
||||||
device_ = dmlc::ParseSignedInt<int>(e.second.c_str(), nullptr, 10);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
reducer_.Configure(policy_);
|
reducer_.Configure(policy_);
|
||||||
|
CHECK(tparam_);
|
||||||
}
|
}
|
||||||
|
|
||||||
bst_float Eval(const HostDeviceVector<bst_float>& preds,
|
bst_float Eval(const HostDeviceVector<bst_float>& preds,
|
||||||
@ -212,9 +211,10 @@ struct EvalEWiseSurvivalBase : public Metric {
|
|||||||
bool distributed) override {
|
bool distributed) override {
|
||||||
CHECK_EQ(preds.Size(), info.labels_lower_bound_.Size());
|
CHECK_EQ(preds.Size(), info.labels_lower_bound_.Size());
|
||||||
CHECK_EQ(preds.Size(), info.labels_upper_bound_.Size());
|
CHECK_EQ(preds.Size(), info.labels_upper_bound_.Size());
|
||||||
|
CHECK(tparam_);
|
||||||
auto result = reducer_.Reduce(
|
auto result =
|
||||||
device_, info.weights_, info.labels_lower_bound_, info.labels_upper_bound_, preds);
|
reducer_.Reduce(*tparam_, info.weights_, info.labels_lower_bound_,
|
||||||
|
info.labels_upper_bound_, preds);
|
||||||
|
|
||||||
double dat[2] {result.Residue(), result.Weights()};
|
double dat[2] {result.Residue(), result.Weights()};
|
||||||
|
|
||||||
@ -252,24 +252,22 @@ struct AFTNLogLikDispatcher : public Metric {
|
|||||||
param_.UpdateAllowUnknown(args);
|
param_.UpdateAllowUnknown(args);
|
||||||
switch (param_.aft_loss_distribution) {
|
switch (param_.aft_loss_distribution) {
|
||||||
case common::ProbabilityDistributionType::kNormal:
|
case common::ProbabilityDistributionType::kNormal:
|
||||||
metric_.reset(new EvalEWiseSurvivalBase<EvalAFTNLogLik<common::NormalDistribution>>());
|
metric_.reset(
|
||||||
|
new EvalEWiseSurvivalBase<EvalAFTNLogLik<common::NormalDistribution>>(
|
||||||
|
tparam_));
|
||||||
break;
|
break;
|
||||||
case common::ProbabilityDistributionType::kLogistic:
|
case common::ProbabilityDistributionType::kLogistic:
|
||||||
metric_.reset(new EvalEWiseSurvivalBase<EvalAFTNLogLik<common::LogisticDistribution>>());
|
metric_.reset(new EvalEWiseSurvivalBase<
|
||||||
|
EvalAFTNLogLik<common::LogisticDistribution>>(tparam_));
|
||||||
break;
|
break;
|
||||||
case common::ProbabilityDistributionType::kExtreme:
|
case common::ProbabilityDistributionType::kExtreme:
|
||||||
metric_.reset(new EvalEWiseSurvivalBase<EvalAFTNLogLik<common::ExtremeDistribution>>());
|
metric_.reset(new EvalEWiseSurvivalBase<
|
||||||
|
EvalAFTNLogLik<common::ExtremeDistribution>>(tparam_));
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
LOG(FATAL) << "Unknown probability distribution";
|
LOG(FATAL) << "Unknown probability distribution";
|
||||||
}
|
}
|
||||||
Args new_args{args};
|
metric_->Configure(args);
|
||||||
// tparam_ doesn't get propagated to the inner metric object because we didn't use
|
|
||||||
// Metric::Create(). I don't think it's a good idea to pollute the metric registry with
|
|
||||||
// specialized versions of the AFT metric, so as a work-around, manually pass the GPU ID
|
|
||||||
// into the inner metric via configuration.
|
|
||||||
new_args.emplace_back("gpu_id", std::to_string(tparam_->gpu_id));
|
|
||||||
metric_->Configure(new_args);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void SaveConfig(Json* p_out) const override {
|
void SaveConfig(Json* p_out) const override {
|
||||||
|
|||||||
@ -11,6 +11,41 @@
|
|||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace common {
|
namespace common {
|
||||||
|
namespace {
|
||||||
|
inline void CheckDeterministicMetricElementWise(StringView name, int32_t device) {
|
||||||
|
auto lparam = CreateEmptyGenericParam(device);
|
||||||
|
std::unique_ptr<Metric> metric{Metric::Create(name.c_str(), &lparam)};
|
||||||
|
metric->Configure(Args{});
|
||||||
|
|
||||||
|
HostDeviceVector<float> predts;
|
||||||
|
MetaInfo info;
|
||||||
|
auto &h_predts = predts.HostVector();
|
||||||
|
|
||||||
|
SimpleLCG lcg;
|
||||||
|
SimpleRealUniformDistribution<float> dist{0.0f, 1.0f};
|
||||||
|
|
||||||
|
size_t n_samples = 2048;
|
||||||
|
h_predts.resize(n_samples);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < n_samples; ++i) {
|
||||||
|
h_predts[i] = dist(&lcg);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto &h_upper = info.labels_upper_bound_.HostVector();
|
||||||
|
auto &h_lower = info.labels_lower_bound_.HostVector();
|
||||||
|
h_lower.resize(n_samples);
|
||||||
|
h_upper.resize(n_samples);
|
||||||
|
for (size_t i = 0; i < n_samples; ++i) {
|
||||||
|
h_lower[i] = 1;
|
||||||
|
h_upper[i] = 10;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto result = metric->Eval(predts, info, false);
|
||||||
|
for (size_t i = 0; i < 8; ++i) {
|
||||||
|
ASSERT_EQ(metric->Eval(predts, info, false), result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
TEST(Metric, DeclareUnifiedTest(AFTNegLogLik)) {
|
TEST(Metric, DeclareUnifiedTest(AFTNegLogLik)) {
|
||||||
auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
|
auto lparam = xgboost::CreateEmptyGenericParam(GPUIDX);
|
||||||
@ -61,6 +96,8 @@ TEST(Metric, DeclareUnifiedTest(IntervalRegressionAccuracy)) {
|
|||||||
EXPECT_FLOAT_EQ(metric->Eval(preds, info, false), 0.50f);
|
EXPECT_FLOAT_EQ(metric->Eval(preds, info, false), 0.50f);
|
||||||
info.labels_lower_bound_.HostVector()[0] = 70.0f;
|
info.labels_lower_bound_.HostVector()[0] = 70.0f;
|
||||||
EXPECT_FLOAT_EQ(metric->Eval(preds, info, false), 0.25f);
|
EXPECT_FLOAT_EQ(metric->Eval(preds, info, false), 0.25f);
|
||||||
|
|
||||||
|
CheckDeterministicMetricElementWise(StringView{"interval-regression-accuracy"}, GPUIDX);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test configuration of AFT metric
|
// Test configuration of AFT metric
|
||||||
@ -75,6 +112,8 @@ TEST(AFTNegLogLikMetric, DeclareUnifiedTest(Configuration)) {
|
|||||||
auto aft_param_json = j_obj["aft_loss_param"];
|
auto aft_param_json = j_obj["aft_loss_param"];
|
||||||
EXPECT_EQ(get<String>(aft_param_json["aft_loss_distribution"]), "normal");
|
EXPECT_EQ(get<String>(aft_param_json["aft_loss_distribution"]), "normal");
|
||||||
EXPECT_EQ(get<String>(aft_param_json["aft_loss_distribution_scale"]), "10");
|
EXPECT_EQ(get<String>(aft_param_json["aft_loss_distribution_scale"]), "10");
|
||||||
|
|
||||||
|
CheckDeterministicMetricElementWise(StringView{"aft-nloglik"}, GPUIDX);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace common
|
} // namespace common
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user