From 0c440677369c743037ac21ae0870c650ec0aed19 Mon Sep 17 00:00:00 2001
From: Dmitry Razdoburdin
Date: Sat, 8 Jun 2024 05:53:23 +0200
Subject: [PATCH] [SYCL] Optimize gradients calculations. (#10325)

---------

Co-authored-by: Dmitry Razdoburdin <>
---
 plugin/sycl/common/linalg_op.h          | 240 ++++++++++++++++++++++++
 plugin/sycl/objective/multiclass_obj.cc |  90 ++++++---
 plugin/sycl/objective/regression_obj.cc | 139 ++++++++------
 3 files changed, 382 insertions(+), 87 deletions(-)
 create mode 100644 plugin/sycl/common/linalg_op.h

diff --git a/plugin/sycl/common/linalg_op.h b/plugin/sycl/common/linalg_op.h
new file mode 100644
index 000000000..07d4a7ef2
--- /dev/null
+++ b/plugin/sycl/common/linalg_op.h
@@ -0,0 +1,240 @@
+/**
+ * Copyright 2021-2024, XGBoost Contributors
+ * \file linalg_op.h
+ */
+#ifndef PLUGIN_SYCL_COMMON_LINALG_OP_H_
+#define PLUGIN_SYCL_COMMON_LINALG_OP_H_
+
+#include 
+#include 
+
+#include "../data.h"
+
+#include 
+
+namespace xgboost {
+namespace sycl {
+namespace linalg {
+
+struct WorkGroupsParams {
+  size_t n_workgroups;
+  size_t workgroup_size;
+};
+
+template <typename Fn>
+::sycl::event GroupWiseKernel(::sycl::queue* qu, int* flag_ptr,
+                              const std::vector<::sycl::event>& events,
+                              const WorkGroupsParams& wg, Fn &&fn) {
+  ::sycl::buffer<int, 1> flag_buf(flag_ptr, 1);
+  auto event = qu->submit([&](::sycl::handler& cgh) {
+    cgh.depends_on(events);
+    auto flag = flag_buf.get_access<::sycl::access::mode::write>(cgh);
+    cgh.parallel_for_work_group<>(::sycl::range<1>(wg.n_workgroups),
+                                  ::sycl::range<1>(wg.workgroup_size),
+                                  [=](::sycl::group<1> group) {
+      group.parallel_for_work_item([&](::sycl::h_item<1> item) {
+        const size_t idx = item.get_global_id()[0];
+        fn(idx, flag);
+      });
+    });
+  });
+  return event;
+}
+
+struct Argument {
+  template <typename T>
+  operator T&&() const;
+};
+
+template <typename Fn, typename Is, typename = void>
+struct ArgumentsPassedImpl
+  : std::false_type {};
+
+template <typename Fn, size_t... Is>
+struct ArgumentsPassedImpl<Fn, std::index_sequence<Is...>,
+                           decltype(std::declval<Fn>()(((void)Is, Argument{})...), void())>
+  : std::true_type {};
+
+template <typename Fn, size_t NumArgs>
+struct ArgumentsPassed : ArgumentsPassedImpl<Fn, std::make_index_sequence<NumArgs>> {};
+
+template <typename InputDType, typename OutputDType,
+          size_t BatchSize, size_t MaxNumInputs>
+class BatchProcessingHelper {
+ public:
+  static constexpr size_t kBatchSize = BatchSize;
+  using InputType = HostDeviceVector<InputDType>;
+  using OutputType = HostDeviceVector<OutputDType>;
+
+ private:
+  template <size_t NumInput = 0>
+  void Host2Buffers(InputDType* in_buffer_ptr, const InputType& input) {
+    /*
+     * Some inputs may have fewer than 1 sample per output symbol.
+     */
+    const size_t sub_sample_rate = ndata_ * sample_rates_[NumInput+1] / input.Size();
+    const size_t n_samples = batch_size_ * sample_rates_[NumInput+1] / sub_sample_rate;
+
+    const InputDType* in_host_ptr = input.HostPointer() +
+                                    batch_begin_ * sample_rates_[NumInput+1] / sub_sample_rate;
+
+    events_[NumInput] =
+      qu_->memcpy(in_buffer_ptr, in_host_ptr, n_samples * sizeof(InputDType),
+                  events_[MaxNumInputs - 2]);
+  }
+
+  template <size_t NumInput = 0, typename... InputTypes>
+  void Host2Buffers(InputDType* in_buffer_ptr, const InputType& input,
+                    const InputTypes&... other_inputs) {
+    // Make a copy for the first input in the list
+    Host2Buffers(in_buffer_ptr, input);
+    // Recursive call for the next inputs
+    InputDType* next_input = in_buffer_.Data() + in_buff_offsets_[NumInput + 1];
+    Host2Buffers<NumInput + 1>(next_input, other_inputs...);
+  }
+
+  void Buffers2Host(OutputType* output) {
+    const size_t n_samples = batch_size_ * sample_rates_[0];
+    OutputDType* out_host_ptr = output->HostPointer() + batch_begin_ * sample_rates_[0];
+    events_[MaxNumInputs - 1] =
+      qu_->memcpy(out_host_ptr, out_buffer_.DataConst(), n_samples * sizeof(OutputDType),
+                  events_[MaxNumInputs - 2]);
+  }
+
+  void Buffers2Host(InputType* output) {
+    const size_t n_samples = batch_size_ * sample_rates_[1];
+    InputDType* out_host_ptr = output->HostPointer() + batch_begin_ * sample_rates_[1];
+    events_[MaxNumInputs - 1] =
+      qu_->memcpy(out_host_ptr, in_buffer_.DataConst(), n_samples * sizeof(InputDType),
+                  events_[MaxNumInputs - 2]);
+  }
+
+  template <size_t NumInputs = 1, typename Fn, typename... InputTypes>
+  void Call(Fn &&fn, const InputDType* input, const InputTypes*... other_inputs) {
+    static_assert(NumInputs <= MaxNumInputs,
+                  "Too many arguments in the passed function");
+    /* The passed lambda may take fewer inputs than MaxNumInputs,
+     * so only the required number of arguments is forwarded.
+     */
+    // 1 for events, 1 for batch_size, 1 for output
+    if constexpr (ArgumentsPassed<Fn, NumInputs + 3>::value) {
+      events_[MaxNumInputs - 2] = fn(events_, batch_size_,
+                                     out_buffer_.Data(), input, other_inputs...);
+    } else {
+      const InputDType* next_input = in_buffer_.DataConst() +
+                                     in_buff_offsets_[MaxNumInputs - 1 - NumInputs];
+      Call<NumInputs + 1>(std::forward<Fn>(fn), next_input, input, other_inputs...);
+    }
+  }
+
+  template <size_t NumInputs = 1, typename Fn, typename... InputTypes>
+  void Call(Fn &&fn, InputDType* io, const InputDType* input, const InputTypes*... other_inputs) {
+    static_assert(NumInputs <= MaxNumInputs,
+                  "Too many arguments in the passed function");
+    if constexpr (ArgumentsPassed<Fn, NumInputs + 3>::value) {
+      events_[MaxNumInputs - 2] = fn(events_, batch_size_,
+                                     io, input, other_inputs...);
+    } else {
+      const InputDType* next_input = in_buffer_.DataConst() +
+                                     in_buff_offsets_[MaxNumInputs - NumInputs];
+      Call<NumInputs + 1>(std::forward<Fn>(fn), io, next_input, input, other_inputs...);
+    }
+  }
+
+  template <size_t NumInputs = 0, typename Fn>
+  void Call(Fn &&fn, InputDType* io) {
+    static_assert(NumInputs <= MaxNumInputs,
+                  "Too many arguments in the passed function");
+    if constexpr (ArgumentsPassed<Fn, NumInputs + 3>::value) {
+      events_[MaxNumInputs - 2] = fn(events_, batch_size_, io);
+    } else {
+      const InputDType* next_input = in_buffer_.DataConst() +
+                                     in_buff_offsets_[MaxNumInputs - 1];
+      Call<NumInputs + 1>(std::forward<Fn>(fn), io, next_input);
+    }
+  }
+
+ public:
+  BatchProcessingHelper() = default;
+
+  // The first element of sample_rate always corresponds to the output sample rate
+  void InitBuffers(::sycl::queue* qu, const std::vector<int>& sample_rate) {
+    assert(sample_rate.size() == MaxNumInputs + 1);
+    sample_rates_ = sample_rate;
+    qu_ = qu;
+    events_.resize(MaxNumInputs + 2);
+    out_buffer_.Resize(qu, kBatchSize * sample_rate.front());
+
+    in_buff_offsets_[0] = 0;
+    for (size_t i = 1; i < MaxNumInputs; ++i) {
+      in_buff_offsets_[i] = in_buff_offsets_[i - 1] + kBatchSize * sample_rate[i];
+    }
+    const size_t in_buff_size = in_buff_offsets_.back() + kBatchSize * sample_rate.back();
+    in_buffer_.Resize(qu, in_buff_size);
+  }
+
+  /*
+   * Batch-wise processing on the SYCL device
+   * output = fn(inputs)
+   */
+  template <typename Fn, typename... InputTypes>
+  void Calculate(Fn &&fn, OutputType* output, const InputTypes&... inputs) {
+    ndata_ = output->Size() / sample_rates_.front();
+    const size_t nBatch = ndata_ / kBatchSize + (ndata_ % kBatchSize > 0);
+    for (size_t batch = 0; batch < nBatch; ++batch) {
+      batch_begin_ = batch * kBatchSize;
+      batch_end_ = (batch == nBatch - 1) ? ndata_ : batch_begin_ + kBatchSize;
+      batch_size_ = batch_end_ - batch_begin_;
+
+      // Iteratively copy all inputs to device buffers
+      Host2Buffers(in_buffer_.Data(), inputs...);
+      // Pack buffers and call the function.
+      // The input pointer is shifted to keep the same order of inputs after packing.
+      Call(std::forward<Fn>(fn), in_buffer_.DataConst() + in_buff_offsets_.back());
+      // Copy results back to the host
+      Buffers2Host(output);
+    }
+  }
+
+  /*
+   * Batch-wise processing on the SYCL device
+   * input = fn(input, other_inputs)
+   */
+  template <typename Fn, typename... InputTypes>
+  void Calculate(Fn &&fn, InputType* input, const InputTypes&... other_inputs) {
+    ndata_ = input->Size();
+    const size_t nBatch = ndata_ / kBatchSize + (ndata_ % kBatchSize > 0);
+    for (size_t batch = 0; batch < nBatch; ++batch) {
+      batch_begin_ = batch * kBatchSize;
+      batch_end_ = (batch == nBatch - 1) ? ndata_ : batch_begin_ + kBatchSize;
+      batch_size_ = batch_end_ - batch_begin_;
+
+      // Iteratively copy all inputs to device buffers.
+      // Inputs are passed by const reference.
+      Host2Buffers(in_buffer_.Data(), *(input), other_inputs...);
+      // Pack buffers and call the function.
+      // The input pointer is shifted to keep the same order of inputs after packing.
+      Call(std::forward<Fn>(fn), in_buffer_.Data());
+      // Copy results back to the host
+      Buffers2Host(input);
+    }
+  }
+
+ private:
+  std::array<size_t, MaxNumInputs> in_buff_offsets_;
+  std::vector<int> sample_rates_;
+  size_t ndata_;
+  size_t batch_begin_;
+  size_t batch_end_;
+  // may be smaller than kBatchSize for the last batch
+  size_t batch_size_;
+  ::sycl::queue* qu_;
+  std::vector<::sycl::event> events_;
+  USMVector<InputDType, MemoryType::on_device> in_buffer_;
+  USMVector<OutputDType, MemoryType::on_device> out_buffer_;
+};
+
+}  // namespace linalg
+}  // namespace sycl
+}  // namespace xgboost
+#endif  // PLUGIN_SYCL_COMMON_LINALG_OP_H_
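
Note on usage (illustration, not part of the patch): the helper is driven in two steps. InitBuffers receives one sample rate per argument, output rate first, and sizes one device buffer per input plus one for the output; Calculate then walks the data in chunks of kBatchSize samples, copies every input to the device, invokes the supplied lambda once per chunk, and copies the result back. The sketch below wires a hypothetical squared-error objective through the helper; the function name, the include paths, and the template-argument order (which follows the reconstruction above) are assumptions for illustration only, not something this patch defines.

// Hypothetical illustration only; not part of the patch.
#include <utility>
#include <vector>
#include "plugin/sycl/common/linalg_op.h"        // assumed include path
#include "xgboost/host_device_vector.h"

namespace example {
using xgboost::bst_float;
using xgboost::GradientPair;
using xgboost::HostDeviceVector;

void SquaredErrorGradients(::sycl::queue* qu,
                           const HostDeviceVector<bst_float>& preds,
                           const HostDeviceVector<bst_float>& labels,
                           HostDeviceVector<GradientPair>* out_gpair) {
  // Two inputs (preds, labels), one GradientPair per sample, batches of 2^22 samples.
  // Template-argument order assumed: <InputDType, OutputDType, BatchSize, MaxNumInputs>.
  xgboost::sycl::linalg::BatchProcessingHelper<bst_float, GradientPair, 1u << 22, 2> helper;
  helper.InitBuffers(qu, {1, 1, 1});             // output rate first, then one rate per input

  auto fn = [qu](const std::vector<::sycl::event>& events, size_t ndata,
                 GradientPair* out, const bst_float* pred, const bst_float* label) {
    return qu->submit([&](::sycl::handler& cgh) {
      cgh.depends_on(events);                    // run after the host-to-device copies
      cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
        const size_t i = pid[0];
        const bst_float diff = pred[i] - label[i];
        out[i] = GradientPair(diff, 1.0f);       // gradient and hessian of squared error
      });
    });
  };
  helper.Calculate(std::move(fn), out_gpair, preds, labels);
  qu->wait_and_throw();
}
}  // namespace example

Because the lambda declares exactly two input pointers, the Call recursion stops after packing preds and labels; a third input such as weights would simply be appended to both the lambda signature and the Calculate call, as the objectives below do.
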
diff --git a/plugin/sycl/objective/multiclass_obj.cc b/plugin/sycl/objective/multiclass_obj.cc
index 5dcc8c3de..25668c830 100644
--- a/plugin/sycl/objective/multiclass_obj.cc
+++ b/plugin/sycl/objective/multiclass_obj.cc
@@ -22,7 +22,10 @@
 
 #include "../../../src/objective/multiclass_param.h"
 
+#include "../common/linalg_op.h"
+
 #include "../device_manager.h"
+#include "../data.h"
 #include 
 
 namespace xgboost {
@@ -32,6 +35,15 @@ namespace obj {
 DMLC_REGISTRY_FILE_TAG(multiclass_obj_sycl);
 
 class SoftmaxMultiClassObj : public ObjFunction {
+  mutable bool are_buffs_init = false;
+
+  void InitBuffers(const std::vector<int>& sample_rate) const {
+    if (!are_buffs_init) {
+      batch_processor_.InitBuffers(&qu_, sample_rate);
+      are_buffs_init = true;
+    }
+  }
+
  public:
   explicit SoftmaxMultiClassObj(bool output_prob)
   : output_prob_(output_prob) {}
@@ -44,7 +56,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
   void GetGradient(const HostDeviceVector<bst_float>& preds,
                    const MetaInfo& info,
                    int iter,
-                   linalg::Matrix<GradientPair>* out_gpair) override {
+                   xgboost::linalg::Matrix<GradientPair>* out_gpair) override {
     if (preds.Size() == 0) return;
     if (info.labels.Size() == 0) return;
 
@@ -66,54 +78,68 @@ class SoftmaxMultiClassObj : public ObjFunction {
           << "Number of weights should be equal to number of data points.";
     }
 
-    ::sycl::buffer<bst_float> preds_buf(preds.HostPointer(), preds.Size());
-    ::sycl::buffer<bst_float> labels_buf(info.labels.Data()->HostPointer(), info.labels.Size());
-    ::sycl::buffer<GradientPair> out_gpair_buf(out_gpair->Data()->HostPointer(),
-                                               out_gpair->Size());
-    ::sycl::buffer<bst_float> weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(),
-                                          is_null_weight ? 1 : info.weights_.Size());
     int flag = 1;
-    {
-      ::sycl::buffer<int> flag_buf(&flag, 1);
-      qu_.submit([&](::sycl::handler& cgh) {
-        auto preds_acc = preds_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto labels_acc = labels_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto weights_acc = weights_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto out_gpair_acc = out_gpair_buf.get_access<::sycl::access::mode::write>(cgh);
-        auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);
-        cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
-          int idx = pid[0];
-
-          bst_float const * point = &preds_acc[idx * nclass];
+    auto objective_fn = [=, &flag]
+                        (const std::vector<::sycl::event>& events,
+                         size_t ndata,
+                         GradientPair* out_gpair,
+                         const bst_float* preds,
+                         const bst_float* labels,
+                         const bst_float* weights) {
+      const size_t wg_size = 32;
+      const size_t nwgs = ndata / wg_size + (ndata % wg_size > 0);
+      return linalg::GroupWiseKernel(&qu_, &flag, events, {nwgs, wg_size},
+                                     [=] (size_t idx, auto flag) {
+        const bst_float* pred = preds + idx * nclass;
 
         // Part of Softmax function
         bst_float wmax = std::numeric_limits<bst_float>::min();
-          for (int k = 0; k < nclass; k++) { wmax = ::sycl::max(point[k], wmax); }
-          float wsum = 0.0f;
-          for (int k = 0; k < nclass; k++) { wsum += ::sycl::exp(point[k] - wmax); }
-          auto label = labels_acc[idx];
+        for (int k = 0; k < nclass; k++) { wmax = ::sycl::max(pred[k], wmax); }
+        bst_float wsum = 0.0f;
+        for (int k = 0; k < nclass; k++) { wsum += ::sycl::exp(pred[k] - wmax); }
+        bst_float label = labels[idx];
+
         if (label < 0 || label >= nclass) {
-            flag_buf_acc[0] = 0;
+          AtomicRef<int> flag_ref(flag[0]);
+          flag_ref = 0;
           label = 0;
         }
-          bst_float wt = is_null_weight ? 1.0f : weights_acc[idx];
+
+        bst_float wt = is_null_weight ? 1.0f : weights[idx];
         for (int k = 0; k < nclass; ++k) {
-            bst_float p = expf(point[k] - wmax) / static_cast<bst_float>(wsum);
+          bst_float p = expf(pred[k] - wmax) / static_cast<bst_float>(wsum);
           const float eps = 1e-16f;
           const bst_float h = ::sycl::max(2.0f * p * (1.0f - p) * wt, eps);
           p = label == k ? p - 1.0f : p;
-            out_gpair_acc[idx * nclass + k] = GradientPair(p * wt, h);
+          out_gpair[idx * nclass + k] = GradientPair(p * wt, h);
         }
-        });
-      }).wait();
+      });
+    };
+
+    // out_gpair and preds have nclass points per sample
+    // labels and weights have 1 point per sample
+    InitBuffers({nclass, nclass, 1, 1});
+    if (is_null_weight) {
+      // Output is passed by pointer
+      // Inputs are passed by const reference
+      batch_processor_.Calculate(std::move(objective_fn),
+                                 out_gpair->Data(),
+                                 preds,
+                                 *(info.labels.Data()));
+    } else {
+      batch_processor_.Calculate(std::move(objective_fn),
+                                 out_gpair->Data(),
+                                 preds,
+                                 *(info.labels.Data()),
+                                 info.weights_);
     }
-    // flag_buf is destroyed, content is copyed to the "flag"
+    qu_.wait_and_throw();
     if (flag == 0) {
       LOG(FATAL) << "SYCL::SoftmaxMultiClassObj: label must be in [0, num_class).";
     }
   }
+
   void PredTransform(HostDeviceVector<bst_float>* io_preds) const override {
     this->Transform(io_preds, output_prob_);
   }
@@ -190,6 +216,8 @@ class SoftmaxMultiClassObj : public ObjFunction {
 
   sycl::DeviceManager device_manager;
   mutable ::sycl::queue qu_;
+  static constexpr size_t kBatchSize = 1u << 22;
+  mutable linalg::BatchProcessingHelper<bst_float, GradientPair, kBatchSize, 3> batch_processor_;
 };
 
 XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax_sycl")
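
Both objectives report bad labels through the same pattern: a host-side int flag starts at 1, GroupWiseKernel wraps it in a one-element buffer, and any work-item that sees an invalid value resets it to 0 (the patch uses the plugin's AtomicRef to make the concurrent store explicit). Below is a minimal standalone sketch of that pattern, not taken from the patch; it assumes labels points to device-accessible (e.g. USM) memory and it adds an explicit bounds check, since outside the batched path nothing pads the allocation up to a multiple of the work-group size.

// Hypothetical, standalone illustration of the flag pattern used by both objectives.
#include <vector>
#include "plugin/sycl/common/linalg_op.h"   // assumed include path

bool LabelsInRange(::sycl::queue* qu, const float* labels /* device-accessible */,
                   size_t ndata, int nclass) {
  int flag = 1;                         // 1 = every label valid
  const size_t wg_size = 32;
  const size_t nwgs = ndata / wg_size + (ndata % wg_size > 0);
  xgboost::sycl::linalg::GroupWiseKernel(qu, &flag, {}, {nwgs, wg_size},
      [=](size_t idx, auto flag_acc) {
        if (idx >= ndata) return;       // the rounded-up range can overshoot ndata
        if (labels[idx] < 0 || labels[idx] >= nclass) {
          flag_acc[0] = 0;              // the patch performs this store through AtomicRef
        }
      });
  // The flag buffer is local to GroupWiseKernel, so its destruction has already
  // synchronized and copied the value back into `flag` by the time it returns.
  return flag == 1;
}
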
diff --git a/plugin/sycl/objective/regression_obj.cc b/plugin/sycl/objective/regression_obj.cc
index 82467a7c4..ee75270fa 100644
--- a/plugin/sycl/objective/regression_obj.cc
+++ b/plugin/sycl/objective/regression_obj.cc
@@ -27,7 +27,10 @@
 #pragma GCC diagnostic pop
 
 #include "../../../src/objective/regression_param.h"
+#include "../common/linalg_op.h"
+
 #include "../device_manager.h"
+#include "../data.h"
 #include 
 
@@ -41,6 +44,14 @@ template <typename Loss>
 class RegLossObj : public ObjFunction {
  protected:
   HostDeviceVector<int> label_correct_;
+  mutable bool are_buffs_init = false;
+
+  void InitBuffers() const {
+    if (!are_buffs_init) {
+      batch_processor_.InitBuffers(&qu_, {1, 1, 1, 1});
+      are_buffs_init = true;
+    }
+  }
 
  public:
   RegLossObj() = default;
@@ -53,63 +64,72 @@ class RegLossObj : public ObjFunction {
 
   void GetGradient(const HostDeviceVector<bst_float>& preds,
                    const MetaInfo &info,
                    int iter,
-                   linalg::Matrix<GradientPair>* out_gpair) override {
-    if (info.labels.Size() == 0) return;
-    CHECK_EQ(preds.Size(), info.labels.Size())
-          << " " << "labels are not correctly provided"
-          << "preds.size=" << preds.Size() << ", label.size=" << info.labels.Size() << ", "
-          << "Loss: " << Loss::Name();
+                   xgboost::linalg::Matrix<GradientPair>* out_gpair) override {
+    if (info.labels.Size() == 0) return;
+    CHECK_EQ(preds.Size(), info.labels.Size())
+          << " " << "labels are not correctly provided"
+          << "preds.size=" << preds.Size() << ", label.size=" << info.labels.Size() << ", "
+          << "Loss: " << Loss::Name();
 
-    size_t const ndata = preds.Size();
-    auto const n_targets = this->Targets(info);
-    out_gpair->Reshape(info.num_row_, n_targets);
+    size_t const ndata = preds.Size();
+    auto const n_targets = this->Targets(info);
+    out_gpair->Reshape(info.num_row_, n_targets);
 
-    // TODO(razdoburdin): add label_correct check
-    label_correct_.Resize(1);
-    label_correct_.Fill(1);
+    // TODO(razdoburdin): add label_correct check
+    label_correct_.Resize(1);
+    label_correct_.Fill(1);
 
-    bool is_null_weight = info.weights_.Size() == 0;
+    bool is_null_weight = info.weights_.Size() == 0;
 
-    ::sycl::buffer<bst_float> preds_buf(preds.HostPointer(), preds.Size());
-    ::sycl::buffer<bst_float> labels_buf(info.labels.Data()->HostPointer(), info.labels.Size());
-    ::sycl::buffer<GradientPair> out_gpair_buf(out_gpair->Data()->HostPointer(),
-                                               out_gpair->Size());
-    ::sycl::buffer<bst_float> weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(),
-                                          is_null_weight ? 1 : info.weights_.Size());
+    auto scale_pos_weight = param_.scale_pos_weight;
+    if (!is_null_weight) {
+      CHECK_EQ(info.weights_.Size(), info.labels.Shape(0))
+        << "Number of weights should be equal to number of data points.";
+    }
 
-    auto scale_pos_weight = param_.scale_pos_weight;
-    if (!is_null_weight) {
-      CHECK_EQ(info.weights_.Size(), info.labels.Shape(0))
-            << "Number of weights should be equal to number of data points.";
-    }
-
-    int flag = 1;
-    {
-      ::sycl::buffer<int> flag_buf(&flag, 1);
-      qu_.submit([&](::sycl::handler& cgh) {
-        auto preds_acc = preds_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto labels_acc = labels_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto weights_acc = weights_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto out_gpair_acc = out_gpair_buf.get_access<::sycl::access::mode::write>(cgh);
-        auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);
-        cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
-          int idx = pid[0];
-          bst_float p = Loss::PredTransform(preds_acc[idx]);
-          bst_float w = is_null_weight ? 1.0f : weights_acc[idx/n_targets];
-          bst_float label = labels_acc[idx];
+    int flag = 1;
+    auto objective_fn = [=, &flag]
+                        (const std::vector<::sycl::event>& events,
+                         size_t ndata,
+                         GradientPair* out_gpair,
+                         const bst_float* preds,
+                         const bst_float* labels,
+                         const bst_float* weights) {
+      const size_t wg_size = 32;
+      const size_t nwgs = ndata / wg_size + (ndata % wg_size > 0);
+      return linalg::GroupWiseKernel(&qu_, &flag, events, {nwgs, wg_size},
+                                     [=] (size_t idx, auto flag) {
+        const bst_float pred = Loss::PredTransform(preds[idx]);
+        bst_float weight = is_null_weight ? 1.0f : weights[idx/n_targets];
+        const bst_float label = labels[idx];
         if (label == 1.0f) {
-            w *= scale_pos_weight;
+          weight *= scale_pos_weight;
         }
         if (!Loss::CheckLabel(label)) {
-            // If there is an incorrect label, the host code will know.
-            flag_buf_acc[0] = 0;
+          AtomicRef<int> flag_ref(flag[0]);
+          flag_ref = 0;
         }
-          out_gpair_acc[idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w,
-                                            Loss::SecondOrderGradient(p, label) * w);
-        });
-      }).wait();
-    }
-    // flag_buf is destroyed, content is copyed to the "flag"
+        out_gpair[idx] = GradientPair(Loss::FirstOrderGradient(pred, label) * weight,
+                                      Loss::SecondOrderGradient(pred, label) * weight);
+      });
+    };
+
+    InitBuffers();
+    if (is_null_weight) {
+      // Output is passed by pointer
+      // Inputs are passed by const reference
+      batch_processor_.Calculate(std::move(objective_fn),
+                                 out_gpair->Data(),
+                                 preds,
+                                 *(info.labels.Data()));
+    } else {
+      batch_processor_.Calculate(std::move(objective_fn),
+                                 out_gpair->Data(),
+                                 preds,
+                                 *(info.labels.Data()),
+                                 info.weights_);
+    }
+    qu_.wait_and_throw();
 
     if (flag == 0) {
       LOG(FATAL) << Loss::LabelErrorMsg();
@@ -121,18 +141,23 @@ class RegLossObj : public ObjFunction {
     return Loss::DefaultEvalMetric();
   }
 
-  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
+  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     size_t const ndata = io_preds->Size();
     if (ndata == 0) return;
-    ::sycl::buffer<bst_float> io_preds_buf(io_preds->HostPointer(), io_preds->Size());
+    InitBuffers();
 
-    qu_.submit([&](::sycl::handler& cgh) {
-      auto io_preds_acc = io_preds_buf.get_access<::sycl::access::mode::read_write>(cgh);
-      cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
-        int idx = pid[0];
-        io_preds_acc[idx] = Loss::PredTransform(io_preds_acc[idx]);
+    batch_processor_.Calculate([=] (const std::vector<::sycl::event>& events,
+                                    size_t ndata,
+                                    bst_float* io_preds) {
+      return qu_.submit([&](::sycl::handler& cgh) {
+        cgh.depends_on(events);
+        cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
+          int idx = pid[0];
+          io_preds[idx] = Loss::PredTransform(io_preds[idx]);
+        });
       });
-    }).wait();
+    }, io_preds);
+    qu_.wait_and_throw();
   }
 
   float ProbToMargin(float base_score) const override {
@@ -163,6 +188,8 @@ class RegLossObj : public ObjFunction {
 
   sycl::DeviceManager device_manager;
   mutable ::sycl::queue qu_;
+  static constexpr size_t kBatchSize = 1u << 22;
+  mutable linalg::BatchProcessingHelper<bst_float, GradientPair, kBatchSize, 3> batch_processor_;
 };
 
 XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegression,
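
Finally, a back-of-the-envelope estimate (not from the patch) of what the fixed batch size costs in device memory, assuming bst_float is 4 bytes and GradientPair holds two such floats:

// Illustrative arithmetic only; kBatchSize matches the constant added above.
#include <cstddef>
#include <cstdio>

int main() {
  constexpr std::size_t kBatchSize = 1u << 22;   // 4,194,304 samples per batch
  constexpr std::size_t kFloat = 4;              // assumed sizeof(bst_float)
  constexpr std::size_t kGPair = 2 * kFloat;     // assumed GradientPair = grad + hess
  // Regression: three float inputs (preds, labels, weights) + one GradientPair output.
  constexpr std::size_t bytes = kBatchSize * (3 * kFloat + kGPair);
  std::printf("regression device buffers per batch: %zu MiB\n", bytes >> 20);  // ~80 MiB
  return 0;
}

For the multiclass objective the same bound scales roughly with the number of classes, since preds and out_gpair both carry nclass entries per sample; this is presumably why the buffers are allocated once in InitBuffers and reused across batches and boosting iterations.
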