Sycl implementation for objective functions (#9846)

---------

Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
Dmitry Razdoburdin
2023-12-12 07:41:50 +01:00
committed by GitHub
parent ddab49a8be
commit 43897b8296
19 changed files with 1129 additions and 423 deletions

View File

@@ -0,0 +1,210 @@
/*!
* Copyright 2015-2023 by Contributors
* \file multiclass_obj.cc
* \brief Definition of multi-class classification objectives.
*/
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include <rabit/rabit.h>
#pragma GCC diagnostic pop
#include <vector>
#include <algorithm>
#include <limits>
#include <utility>
#include "xgboost/parameter.h"
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#include "xgboost/data.h"
#include "../../src/common/math.h"
#pragma GCC diagnostic pop
#include "xgboost/logging.h"
#include "xgboost/objective.h"
#include "xgboost/json.h"
#include "xgboost/span.h"
#include "../../../src/objective/multiclass_param.h"
#include "../device_manager.h"
#include <CL/sycl.hpp>
namespace xgboost {
namespace sycl {
namespace obj {
DMLC_REGISTRY_FILE_TAG(multiclass_obj_sycl);
class SoftmaxMultiClassObj : public ObjFunction {
public:
explicit SoftmaxMultiClassObj(bool output_prob)
: output_prob_(output_prob) {}
void Configure(Args const& args) override {
param_.UpdateAllowUnknown(args);
qu_ = device_manager.GetQueue(ctx_->Device());
}
void GetGradient(const HostDeviceVector<bst_float>& preds,
const MetaInfo& info,
int iter,
linalg::Matrix<GradientPair>* out_gpair) override {
if (preds.Size() == 0) return;
if (info.labels.Size() == 0) return;
CHECK(preds.Size() == (static_cast<size_t>(param_.num_class) * info.labels.Size()))
<< "SoftmaxMultiClassObj: label size and pred size does not match.\n"
<< "label.Size() * num_class: "
<< info.labels.Size() * static_cast<size_t>(param_.num_class) << "\n"
<< "num_class: " << param_.num_class << "\n"
<< "preds.Size(): " << preds.Size();
const int nclass = param_.num_class;
const auto ndata = static_cast<int64_t>(preds.Size() / nclass);
out_gpair->Reshape(info.num_row_, static_cast<std::uint64_t>(nclass));
const bool is_null_weight = info.weights_.Size() == 0;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), ndata)
<< "Number of weights should be equal to number of data points.";
}
::sycl::buffer<bst_float, 1> preds_buf(preds.HostPointer(), preds.Size());
::sycl::buffer<bst_float, 1> labels_buf(info.labels.Data()->HostPointer(), info.labels.Size());
::sycl::buffer<GradientPair, 1> out_gpair_buf(out_gpair->Data()->HostPointer(),
out_gpair->Size());
::sycl::buffer<bst_float, 1> weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(),
is_null_weight ? 1 : info.weights_.Size());
int flag = 1;
{
::sycl::buffer<int, 1> flag_buf(&flag, 1);
qu_.submit([&](::sycl::handler& cgh) {
auto preds_acc = preds_buf.get_access<::sycl::access::mode::read>(cgh);
auto labels_acc = labels_buf.get_access<::sycl::access::mode::read>(cgh);
auto weights_acc = weights_buf.get_access<::sycl::access::mode::read>(cgh);
auto out_gpair_acc = out_gpair_buf.get_access<::sycl::access::mode::write>(cgh);
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
int idx = pid[0];
bst_float const * point = &preds_acc[idx * nclass];
// Part of Softmax function
bst_float wmax = std::numeric_limits<bst_float>::min();
for (int k = 0; k < nclass; k++) { wmax = ::sycl::max(point[k], wmax); }
float wsum = 0.0f;
for (int k = 0; k < nclass; k++) { wsum += ::sycl::exp(point[k] - wmax); }
auto label = labels_acc[idx];
if (label < 0 || label >= nclass) {
flag_buf_acc[0] = 0;
label = 0;
}
bst_float wt = is_null_weight ? 1.0f : weights_acc[idx];
for (int k = 0; k < nclass; ++k) {
bst_float p = expf(point[k] - wmax) / static_cast<float>(wsum);
const float eps = 1e-16f;
const bst_float h = ::sycl::max(2.0f * p * (1.0f - p) * wt, eps);
p = label == k ? p - 1.0f : p;
out_gpair_acc[idx * nclass + k] = GradientPair(p * wt, h);
}
});
}).wait();
}
// flag_buf is destroyed, content is copyed to the "flag"
if (flag == 0) {
LOG(FATAL) << "SYCL::SoftmaxMultiClassObj: label must be in [0, num_class).";
}
}
void PredTransform(HostDeviceVector<bst_float>* io_preds) const override {
this->Transform(io_preds, output_prob_);
}
void EvalTransform(HostDeviceVector<bst_float>* io_preds) override {
this->Transform(io_preds, true);
}
const char* DefaultEvalMetric() const override {
return "mlogloss";
}
inline void Transform(HostDeviceVector<bst_float> *io_preds, bool prob) const {
if (io_preds->Size() == 0) return;
const int nclass = param_.num_class;
const auto ndata = static_cast<int64_t>(io_preds->Size() / nclass);
max_preds_.Resize(ndata);
{
::sycl::buffer<bst_float, 1> io_preds_buf(io_preds->HostPointer(), io_preds->Size());
if (prob) {
qu_.submit([&](::sycl::handler& cgh) {
auto io_preds_acc = io_preds_buf.get_access<::sycl::access::mode::read_write>(cgh);
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
int idx = pid[0];
auto it = io_preds_acc.begin() + idx * nclass;
common::Softmax(it, it + nclass);
});
}).wait();
} else {
::sycl::buffer<bst_float, 1> max_preds_buf(max_preds_.HostPointer(), max_preds_.Size());
qu_.submit([&](::sycl::handler& cgh) {
auto io_preds_acc = io_preds_buf.get_access<::sycl::access::mode::read>(cgh);
auto max_preds_acc = max_preds_buf.get_access<::sycl::access::mode::read_write>(cgh);
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
int idx = pid[0];
auto it = io_preds_acc.begin() + idx * nclass;
max_preds_acc[idx] = common::FindMaxIndex(it, it + nclass) - it;
});
}).wait();
}
}
if (!prob) {
io_preds->Resize(max_preds_.Size());
io_preds->Copy(max_preds_);
}
}
struct ObjInfo Task() const override {return {ObjInfo::kClassification}; }
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
if (this->output_prob_) {
out["name"] = String("multi:softprob");
} else {
out["name"] = String("multi:softmax");
}
out["softmax_multiclass_param"] = ToJson(param_);
}
void LoadConfig(Json const& in) override {
FromJson(in["softmax_multiclass_param"], &param_);
}
private:
// output probability
bool output_prob_;
// parameter
xgboost::obj::SoftmaxMultiClassParam param_;
// Cache for max_preds
mutable HostDeviceVector<bst_float> max_preds_;
sycl::DeviceManager device_manager;
mutable ::sycl::queue qu_;
};
XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax_sycl")
.describe("Softmax for multi-class classification, output class index.")
.set_body([]() { return new SoftmaxMultiClassObj(false); });
XGBOOST_REGISTER_OBJECTIVE(SoftprobMultiClass, "multi:softprob_sycl")
.describe("Softmax for multi-class classification, output probability distribution.")
.set_body([]() { return new SoftmaxMultiClassObj(true); });
} // namespace obj
} // namespace sycl
} // namespace xgboost

View File

@@ -0,0 +1,197 @@
/*!
* Copyright 2015-2023 by Contributors
* \file regression_obj.cc
* \brief Definition of regression objectives.
*/
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include <xgboost/logging.h>
#include <xgboost/objective.h>
#pragma GCC diagnostic pop
#include <rabit/rabit.h>
#include <cmath>
#include <memory>
#include <vector>
#include "xgboost/host_device_vector.h"
#include "xgboost/json.h"
#include "xgboost/parameter.h"
#include "xgboost/span.h"
#include "../../src/common/transform.h"
#include "../../src/common/common.h"
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#include "../../../src/objective/regression_loss.h"
#pragma GCC diagnostic pop
#include "../../../src/objective/regression_param.h"
#include "../device_manager.h"
#include <CL/sycl.hpp>
namespace xgboost {
namespace sycl {
namespace obj {
DMLC_REGISTRY_FILE_TAG(regression_obj_sycl);
template<typename Loss>
class RegLossObj : public ObjFunction {
protected:
HostDeviceVector<int> label_correct_;
public:
RegLossObj() = default;
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.UpdateAllowUnknown(args);
qu_ = device_manager.GetQueue(ctx_->Device());
}
void GetGradient(const HostDeviceVector<bst_float>& preds,
const MetaInfo &info,
int iter,
linalg::Matrix<GradientPair>* out_gpair) override {
if (info.labels.Size() == 0) return;
CHECK_EQ(preds.Size(), info.labels.Size())
<< " " << "labels are not correctly provided"
<< "preds.size=" << preds.Size() << ", label.size=" << info.labels.Size() << ", "
<< "Loss: " << Loss::Name();
size_t const ndata = preds.Size();
auto const n_targets = this->Targets(info);
out_gpair->Reshape(info.num_row_, n_targets);
// TODO(razdoburdin): add label_correct check
label_correct_.Resize(1);
label_correct_.Fill(1);
bool is_null_weight = info.weights_.Size() == 0;
::sycl::buffer<bst_float, 1> preds_buf(preds.HostPointer(), preds.Size());
::sycl::buffer<bst_float, 1> labels_buf(info.labels.Data()->HostPointer(), info.labels.Size());
::sycl::buffer<GradientPair, 1> out_gpair_buf(out_gpair->Data()->HostPointer(),
out_gpair->Size());
::sycl::buffer<bst_float, 1> weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(),
is_null_weight ? 1 : info.weights_.Size());
auto scale_pos_weight = param_.scale_pos_weight;
if (!is_null_weight) {
CHECK_EQ(info.weights_.Size(), info.labels.Shape(0))
<< "Number of weights should be equal to number of data points.";
}
int flag = 1;
{
::sycl::buffer<int, 1> flag_buf(&flag, 1);
qu_.submit([&](::sycl::handler& cgh) {
auto preds_acc = preds_buf.get_access<::sycl::access::mode::read>(cgh);
auto labels_acc = labels_buf.get_access<::sycl::access::mode::read>(cgh);
auto weights_acc = weights_buf.get_access<::sycl::access::mode::read>(cgh);
auto out_gpair_acc = out_gpair_buf.get_access<::sycl::access::mode::write>(cgh);
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
int idx = pid[0];
bst_float p = Loss::PredTransform(preds_acc[idx]);
bst_float w = is_null_weight ? 1.0f : weights_acc[idx/n_targets];
bst_float label = labels_acc[idx];
if (label == 1.0f) {
w *= scale_pos_weight;
}
if (!Loss::CheckLabel(label)) {
// If there is an incorrect label, the host code will know.
flag_buf_acc[0] = 0;
}
out_gpair_acc[idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w,
Loss::SecondOrderGradient(p, label) * w);
});
}).wait();
}
// flag_buf is destroyed, content is copyed to the "flag"
if (flag == 0) {
LOG(FATAL) << Loss::LabelErrorMsg();
}
}
public:
const char* DefaultEvalMetric() const override {
return Loss::DefaultEvalMetric();
}
void PredTransform(HostDeviceVector<float> *io_preds) const override {
size_t const ndata = io_preds->Size();
if (ndata == 0) return;
::sycl::buffer<bst_float, 1> io_preds_buf(io_preds->HostPointer(), io_preds->Size());
qu_.submit([&](::sycl::handler& cgh) {
auto io_preds_acc = io_preds_buf.get_access<::sycl::access::mode::read_write>(cgh);
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
int idx = pid[0];
io_preds_acc[idx] = Loss::PredTransform(io_preds_acc[idx]);
});
}).wait();
}
float ProbToMargin(float base_score) const override {
return Loss::ProbToMargin(base_score);
}
struct ObjInfo Task() const override {
return Loss::Info();
};
uint32_t Targets(MetaInfo const& info) const override {
// Multi-target regression.
return std::max(static_cast<size_t>(1), info.labels.Shape(1));
}
void SaveConfig(Json* p_out) const override {
auto& out = *p_out;
out["name"] = String(Loss::Name());
out["reg_loss_param"] = ToJson(param_);
}
void LoadConfig(Json const& in) override {
FromJson(in["reg_loss_param"], &param_);
}
protected:
xgboost::obj::RegLossParam param_;
sycl::DeviceManager device_manager;
mutable ::sycl::queue qu_;
};
XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegression,
std::string(xgboost::obj::LinearSquareLoss::Name()) + "_sycl")
.describe("Regression with squared error with SYCL backend.")
.set_body([]() { return new RegLossObj<xgboost::obj::LinearSquareLoss>(); });
XGBOOST_REGISTER_OBJECTIVE(SquareLogError,
std::string(xgboost::obj::SquaredLogError::Name()) + "_sycl")
.describe("Regression with root mean squared logarithmic error with SYCL backend.")
.set_body([]() { return new RegLossObj<xgboost::obj::SquaredLogError>(); });
XGBOOST_REGISTER_OBJECTIVE(LogisticRegression,
std::string(xgboost::obj::LogisticRegression::Name()) + "_sycl")
.describe("Logistic regression for probability regression task with SYCL backend.")
.set_body([]() { return new RegLossObj<xgboost::obj::LogisticRegression>(); });
XGBOOST_REGISTER_OBJECTIVE(LogisticClassification,
std::string(xgboost::obj::LogisticClassification::Name()) + "_sycl")
.describe("Logistic regression for binary classification task with SYCL backend.")
.set_body([]() { return new RegLossObj<xgboost::obj::LogisticClassification>(); });
XGBOOST_REGISTER_OBJECTIVE(LogisticRaw,
std::string(xgboost::obj::LogisticRaw::Name()) + "_sycl")
.describe("Logistic regression for classification, output score "
"before logistic transformation with SYCL backend.")
.set_body([]() { return new RegLossObj<xgboost::obj::LogisticRaw>(); });
} // namespace obj
} // namespace sycl
} // namespace xgboost