[SYCL] Optimize gradients calculations. (#10325)

---------

Co-authored-by: Dmitry Razdoburdin <>
Authored by Dmitry Razdoburdin on 2024-06-08 05:53:23 +02:00; committed by GitHub.
parent c9f5fcaf21
commit 0c44067736
3 changed files with 382 additions and 87 deletions
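In short: both SYCL objectives (softmax multiclass and the templated regression losses) stop creating ::sycl::buffer staging memory on every GetGradient call. Each objective instead owns a lazily initialized linalg::BatchProcessingHelper that streams predictions, labels, and weights through fixed-size device buffers and runs the gradient computation as a group-wise kernel; invalid labels are reported through an atomic flag rather than a dedicated flag buffer. The helper itself comes from ../common/linalg_op.h, the third changed file, which is not reproduced below.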

plugin/sycl/objective/multiclass_obj.cc

@@ -22,7 +22,10 @@
 #include "../../../src/objective/multiclass_param.h"
+
+#include "../common/linalg_op.h"
 #include "../device_manager.h"
+#include "../data.h"
 #include <CL/sycl.hpp>

 namespace xgboost {
@@ -32,6 +35,15 @@ namespace obj {
 DMLC_REGISTRY_FILE_TAG(multiclass_obj_sycl);

 class SoftmaxMultiClassObj : public ObjFunction {
+  mutable bool are_buffs_init = false;
+
+  void InitBuffers(const std::vector<int>& sample_rate) const {
+    if (!are_buffs_init) {
+      batch_processor_.InitBuffers(&qu_, sample_rate);
+      are_buffs_init = true;
+    }
+  }
+
  public:
  explicit SoftmaxMultiClassObj(bool output_prob)
    : output_prob_(output_prob) {}
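The lazy InitBuffers guard means the helper's staging buffers are allocated once, on the first GetGradient call, and reused afterwards. Judging only from the call site later in this file (InitBuffers({nclass, nclass, 1, 1})), sample_rate seems to give the number of elements each data point contributes to the output and to each input: nclass gradient pairs and nclass predictions per row, but a single label and a single weight. That reading is inferred from this diff, not from the helper's source.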
@@ -44,7 +56,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
   void GetGradient(const HostDeviceVector<bst_float>& preds,
                    const MetaInfo& info,
                    int iter,
-                   linalg::Matrix<GradientPair>* out_gpair) override {
+                   xgboost::linalg::Matrix<GradientPair>* out_gpair) override {
     if (preds.Size() == 0) return;
     if (info.labels.Size() == 0) return;
@@ -66,54 +78,68 @@ class SoftmaxMultiClassObj : public ObjFunction {
           << "Number of weights should be equal to number of data points.";
     }

-    ::sycl::buffer<bst_float, 1> preds_buf(preds.HostPointer(), preds.Size());
-    ::sycl::buffer<bst_float, 1> labels_buf(info.labels.Data()->HostPointer(), info.labels.Size());
-    ::sycl::buffer<GradientPair, 1> out_gpair_buf(out_gpair->Data()->HostPointer(),
-                                                  out_gpair->Size());
-    ::sycl::buffer<bst_float, 1> weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(),
-                                             is_null_weight ? 1 : info.weights_.Size());
-
     int flag = 1;
-    {
-      ::sycl::buffer<int, 1> flag_buf(&flag, 1);
-      qu_.submit([&](::sycl::handler& cgh) {
-        auto preds_acc = preds_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto labels_acc = labels_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto weights_acc = weights_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto out_gpair_acc = out_gpair_buf.get_access<::sycl::access::mode::write>(cgh);
-        auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);
-        cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
-          int idx = pid[0];
-          bst_float const * point = &preds_acc[idx * nclass];
+    auto objective_fn = [=, &flag]
+                        (const std::vector<::sycl::event>& events,
+                         size_t ndata,
+                         GradientPair* out_gpair,
+                         const bst_float* preds,
+                         const bst_float* labels,
+                         const bst_float* weights) {
+      const size_t wg_size = 32;
+      const size_t nwgs = ndata / wg_size + (ndata % wg_size > 0);
+      return linalg::GroupWiseKernel(&qu_, &flag, events, {nwgs, wg_size},
+        [=] (size_t idx, auto flag) {
+          const bst_float* pred = preds + idx * nclass;
           // Part of Softmax function
           bst_float wmax = std::numeric_limits<bst_float>::min();
-          for (int k = 0; k < nclass; k++) { wmax = ::sycl::max(point[k], wmax); }
-          float wsum = 0.0f;
-          for (int k = 0; k < nclass; k++) { wsum += ::sycl::exp(point[k] - wmax); }
-          auto label = labels_acc[idx];
+          for (int k = 0; k < nclass; k++) { wmax = ::sycl::max(pred[k], wmax); }
+          bst_float wsum = 0.0f;
+          for (int k = 0; k < nclass; k++) { wsum += ::sycl::exp(pred[k] - wmax); }
+          bst_float label = labels[idx];
           if (label < 0 || label >= nclass) {
-            flag_buf_acc[0] = 0;
+            AtomicRef<int> flag_ref(flag[0]);
+            flag_ref = 0;
             label = 0;
           }
-          bst_float wt = is_null_weight ? 1.0f : weights_acc[idx];
+          bst_float wt = is_null_weight ? 1.0f : weights[idx];
           for (int k = 0; k < nclass; ++k) {
-            bst_float p = expf(point[k] - wmax) / static_cast<float>(wsum);
+            bst_float p = expf(pred[k] - wmax) / static_cast<float>(wsum);
             const float eps = 1e-16f;
             const bst_float h = ::sycl::max(2.0f * p * (1.0f - p) * wt, eps);
             p = label == k ? p - 1.0f : p;
-            out_gpair_acc[idx * nclass + k] = GradientPair(p * wt, h);
+            out_gpair[idx * nclass + k] = GradientPair(p * wt, h);
           }
-        });
-      }).wait();
+        });
+    };
+
+    // out_gpair and preds have nclass points per sample
+    // labels and weights have 1 point per sample
+    InitBuffers({nclass, nclass, 1, 1});
+    if (is_null_weight) {
+      // Output is passed by pointer
+      // Inputs are passed by const reference
+      batch_processor_.Calculate(std::move(objective_fn),
+                                 out_gpair->Data(),
+                                 preds,
+                                 *(info.labels.Data()));
+    } else {
+      batch_processor_.Calculate(std::move(objective_fn),
+                                 out_gpair->Data(),
+                                 preds,
+                                 *(info.labels.Data()),
+                                 info.weights_);
     }
-    // flag_buf is destroyed, content is copyed to the "flag"
+    qu_.wait_and_throw();

     if (flag == 0) {
       LOG(FATAL) << "SYCL::SoftmaxMultiClassObj: label must be in [0, num_class).";
     }
   }

   void PredTransform(HostDeviceVector<bst_float>* io_preds) const override {
     this->Transform(io_preds, output_prob_);
   }
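Two details of the rewrite are worth spelling out. First, nwgs = ndata / wg_size + (ndata % wg_size > 0) is a ceiling division, ⌈ndata / 32⌉, so a final partial work-group covers the remainder. Second, the kernel body is now a callback over raw pointers (objective_fn), which lets batch_processor_.Calculate slice arbitrarily large datasets into bounded windows. Below is a minimal sketch of that batching idea only; CalculateInBatches and its signature are invented for illustration and are not the actual BatchProcessingHelper API:

    #include <algorithm>
    #include <cstddef>

    // Hypothetical driver showing the windowing pattern. The real helper also
    // owns the device staging buffers and chains host<->device copies and the
    // kernel through SYCL events instead of calling fn synchronously.
    template <typename Fn>
    void CalculateInBatches(std::size_t n_items, std::size_t batch_size, Fn&& fn) {
      for (std::size_t begin = 0; begin < n_items; begin += batch_size) {
        const std::size_t length = std::min(batch_size, n_items - begin);
        fn(begin, length);  // copy in, launch objective kernel, copy gradients out
      }
    }

With kBatchSize = 1u << 22 (see the members added below), a 10-million-sample dataset would be processed in ⌈10,000,000 / 4,194,304⌉ = 3 windows, assuming the batch size is counted in samples.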
@@ -190,6 +216,8 @@ class SoftmaxMultiClassObj : public ObjFunction {
   sycl::DeviceManager device_manager;
   mutable ::sycl::queue qu_;
+  static constexpr size_t kBatchSize = 1u << 22;
+  mutable linalg::BatchProcessingHelper<GradientPair, bst_float, kBatchSize, 3> batch_processor_;
 };

 XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax_sycl")
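For scale: kBatchSize = 1u << 22 is 4,194,304 elements per batch, i.e. roughly 16 MiB per bst_float staging buffer (4 bytes each) and 32 MiB for the GradientPair output (8 bytes each), regardless of dataset size. The final template argument 3 plausibly corresponds to the three input vectors passed to Calculate here: predictions, labels, and weights.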

plugin/sycl/objective/regression_obj.cc

@@ -27,7 +27,10 @@
 #pragma GCC diagnostic pop
 #include "../../../src/objective/regression_param.h"
+
+#include "../common/linalg_op.h"
 #include "../device_manager.h"
+#include "../data.h"
 #include <CL/sycl.hpp>
@@ -41,6 +44,14 @@ template<typename Loss>
 class RegLossObj : public ObjFunction {
  protected:
   HostDeviceVector<int> label_correct_;
+  mutable bool are_buffs_init = false;
+
+  void InitBuffers() const {
+    if (!are_buffs_init) {
+      batch_processor_.InitBuffers(&qu_, {1, 1, 1, 1});
+      are_buffs_init = true;
+    }
+  }

  public:
   RegLossObj() = default;
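Note the contrast with the multiclass objective above: regression initializes the helper with a sample rate of {1, 1, 1, 1}, since each row contributes exactly one gradient pair, one prediction, one label, and one weight, so all four staging buffers share the same per-row sizing.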
@@ -53,63 +64,72 @@ class RegLossObj : public ObjFunction {
   void GetGradient(const HostDeviceVector<bst_float>& preds,
                    const MetaInfo &info,
                    int iter,
-                   linalg::Matrix<GradientPair>* out_gpair) override {
-    if (info.labels.Size() == 0) return;
-    CHECK_EQ(preds.Size(), info.labels.Size())
-          << " " << "labels are not correctly provided"
-          << "preds.size=" << preds.Size() << ", label.size=" << info.labels.Size() << ", "
-          << "Loss: " << Loss::Name();
+                   xgboost::linalg::Matrix<GradientPair>* out_gpair) override {
+    if (info.labels.Size() == 0) return;
+    CHECK_EQ(preds.Size(), info.labels.Size())
+          << " " << "labels are not correctly provided"
+          << "preds.size=" << preds.Size() << ", label.size=" << info.labels.Size() << ", "
+          << "Loss: " << Loss::Name();

-    size_t const ndata = preds.Size();
-    auto const n_targets = this->Targets(info);
-    out_gpair->Reshape(info.num_row_, n_targets);
+    size_t const ndata = preds.Size();
+    auto const n_targets = this->Targets(info);
+    out_gpair->Reshape(info.num_row_, n_targets);

-    // TODO(razdoburdin): add label_correct check
-    label_correct_.Resize(1);
-    label_correct_.Fill(1);
+    // TODO(razdoburdin): add label_correct check
+    label_correct_.Resize(1);
+    label_correct_.Fill(1);

-    bool is_null_weight = info.weights_.Size() == 0;
+    bool is_null_weight = info.weights_.Size() == 0;

-    ::sycl::buffer<bst_float, 1> preds_buf(preds.HostPointer(), preds.Size());
-    ::sycl::buffer<bst_float, 1> labels_buf(info.labels.Data()->HostPointer(), info.labels.Size());
-    ::sycl::buffer<GradientPair, 1> out_gpair_buf(out_gpair->Data()->HostPointer(),
-                                                  out_gpair->Size());
-    ::sycl::buffer<bst_float, 1> weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(),
-                                             is_null_weight ? 1 : info.weights_.Size());
-
-    auto scale_pos_weight = param_.scale_pos_weight;
-    if (!is_null_weight) {
-      CHECK_EQ(info.weights_.Size(), info.labels.Shape(0))
-          << "Number of weights should be equal to number of data points.";
-    }
+    auto scale_pos_weight = param_.scale_pos_weight;
+    if (!is_null_weight) {
+      CHECK_EQ(info.weights_.Size(), info.labels.Shape(0))
+          << "Number of weights should be equal to number of data points.";
+    }

-    int flag = 1;
-    {
-      ::sycl::buffer<int, 1> flag_buf(&flag, 1);
-      qu_.submit([&](::sycl::handler& cgh) {
-        auto preds_acc = preds_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto labels_acc = labels_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto weights_acc = weights_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto out_gpair_acc = out_gpair_buf.get_access<::sycl::access::mode::write>(cgh);
-        auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);
-        cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
-          int idx = pid[0];
-          bst_float p = Loss::PredTransform(preds_acc[idx]);
-          bst_float w = is_null_weight ? 1.0f : weights_acc[idx/n_targets];
-          bst_float label = labels_acc[idx];
+    int flag = 1;
+    auto objective_fn = [=, &flag]
+                        (const std::vector<::sycl::event>& events,
+                         size_t ndata,
+                         GradientPair* out_gpair,
+                         const bst_float* preds,
+                         const bst_float* labels,
+                         const bst_float* weights) {
+      const size_t wg_size = 32;
+      const size_t nwgs = ndata / wg_size + (ndata % wg_size > 0);
+      return linalg::GroupWiseKernel(&qu_, &flag, events, {nwgs, wg_size},
+        [=] (size_t idx, auto flag) {
+          const bst_float pred = Loss::PredTransform(preds[idx]);
+          bst_float weight = is_null_weight ? 1.0f : weights[idx/n_targets];
+          const bst_float label = labels[idx];
           if (label == 1.0f) {
-            w *= scale_pos_weight;
+            weight *= scale_pos_weight;
           }
           if (!Loss::CheckLabel(label)) {
             // If there is an incorrect label, the host code will know.
-            flag_buf_acc[0] = 0;
+            AtomicRef<int> flag_ref(flag[0]);
+            flag_ref = 0;
           }
-          out_gpair_acc[idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w,
-                                            Loss::SecondOrderGradient(p, label) * w);
-        });
-      }).wait();
-    }
-    // flag_buf is destroyed, content is copyed to the "flag"
+          out_gpair[idx] = GradientPair(Loss::FirstOrderGradient(pred, label) * weight,
+                                        Loss::SecondOrderGradient(pred, label) * weight);
+        });
+    };
+
+    InitBuffers();
+    if (is_null_weight) {
+      // Output is passed by pointer
+      // Inputs are passed by const reference
+      batch_processor_.Calculate(std::move(objective_fn),
+                                 out_gpair->Data(),
+                                 preds,
+                                 *(info.labels.Data()));
+    } else {
+      batch_processor_.Calculate(std::move(objective_fn),
+                                 out_gpair->Data(),
+                                 preds,
+                                 *(info.labels.Data()),
+                                 info.weights_);
+    }
+    qu_.wait_and_throw();

     if (flag == 0) {
       LOG(FATAL) << Loss::LabelErrorMsg();
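The flag handling deserves a note: the old code routed the error bit through a one-element ::sycl::buffer whose destruction copied the value back to the host (the deleted "flag_buf is destroyed" comment). The new code writes through AtomicRef, the plugin's alias over ::sycl::atomic_ref, so the same host-side int can be reused across all batches. A self-contained sketch of this device-to-host error-flag pattern in plain SYCL 2020 (the USM allocation here is an illustrative choice, not how the plugin stores the flag):

    #include <CL/sycl.hpp>

    int main() {
      ::sycl::queue qu;
      // Shared USM so work-items can write the flag and the host can read it back.
      int* flag = ::sycl::malloc_shared<int>(1, qu);
      *flag = 1;  // 1 means "all labels valid"

      qu.parallel_for(::sycl::range<1>(1024), [=](::sycl::id<1> pid) {
        const float label = (pid[0] == 512) ? -1.0f : 0.0f;  // one bad label
        if (label < 0.0f) {
          // Relaxed atomic store: many work-items may race to clear the flag,
          // but they all write the same value, so no ordering is needed.
          ::sycl::atomic_ref<int, ::sycl::memory_order::relaxed,
                             ::sycl::memory_scope::device> flag_ref(*flag);
          flag_ref.store(0);
        }
      }).wait();

      const int ok = *flag;  // host-side check, mirroring the LOG(FATAL) above
      ::sycl::free(flag, qu);
      return ok ? 0 : 1;
    }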
@@ -121,18 +141,23 @@ class RegLossObj : public ObjFunction {
     return Loss::DefaultEvalMetric();
   }

-  void PredTransform(HostDeviceVector<float> *io_preds) const override {
+  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     size_t const ndata = io_preds->Size();
     if (ndata == 0) return;
-    ::sycl::buffer<bst_float, 1> io_preds_buf(io_preds->HostPointer(), io_preds->Size());
+    InitBuffers();

-    qu_.submit([&](::sycl::handler& cgh) {
-      auto io_preds_acc = io_preds_buf.get_access<::sycl::access::mode::read_write>(cgh);
-      cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
-        int idx = pid[0];
-        io_preds_acc[idx] = Loss::PredTransform(io_preds_acc[idx]);
+    batch_processor_.Calculate([=] (const std::vector<::sycl::event>& events,
+                                    size_t ndata,
+                                    bst_float* io_preds) {
+      return qu_.submit([&](::sycl::handler& cgh) {
+        cgh.depends_on(events);
+        cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
+          int idx = pid[0];
+          io_preds[idx] = Loss::PredTransform(io_preds[idx]);
+        });
       });
-    }).wait();
+    }, io_preds);
+    qu_.wait_and_throw();
   }

   float ProbToMargin(float base_score) const override {
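PredTransform shows the other half of the scheme: instead of a blocking .wait() per submission, the lambda receives the helper's pending copy events and turns them into a kernel dependency with cgh.depends_on, and a single qu_.wait_and_throw() at the end synchronizes the whole pipeline. A standalone illustration of that dependency chaining (the sigmoid stands in for Loss::PredTransform; any transform works):

    #include <CL/sycl.hpp>
    #include <vector>

    int main() {
      ::sycl::queue qu;
      std::vector<float> host(1 << 10, 2.0f);
      float* dev = ::sycl::malloc_device<float>(host.size(), qu);

      // Stage 1: asynchronous host->device copy; keep the event, do not wait.
      ::sycl::event copy_in = qu.memcpy(dev, host.data(), host.size() * sizeof(float));

      // Stage 2: the kernel declares the copy as a dependency instead of blocking.
      ::sycl::event compute = qu.submit([&](::sycl::handler& cgh) {
        cgh.depends_on(copy_in);
        cgh.parallel_for(::sycl::range<1>(host.size()), [=](::sycl::id<1> i) {
          dev[i] = 1.0f / (1.0f + ::sycl::exp(-dev[i]));
        });
      });

      // Stage 3: copy back, dependent on the kernel.
      qu.submit([&](::sycl::handler& cgh) {
        cgh.depends_on(compute);
        cgh.memcpy(host.data(), dev, host.size() * sizeof(float));
      });

      // One synchronization point for the whole chain, as in PredTransform above.
      qu.wait_and_throw();

      ::sycl::free(dev, qu);
      return 0;
    }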
@@ -163,6 +188,8 @@ class RegLossObj : public ObjFunction {
   sycl::DeviceManager device_manager;
   mutable ::sycl::queue qu_;
+  static constexpr size_t kBatchSize = 1u << 22;
+  mutable linalg::BatchProcessingHelper<GradientPair, bst_float, kBatchSize, 3> batch_processor_;
 };

 XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegression,