[SYCL] Optimize gradients calculations. (#10325)

Co-authored-by: Dmitry Razdoburdin <>

parent c9f5fcaf21
commit 0c44067736

plugin/sycl/common/linalg_op.h (new file, 240 lines added)
@@ -0,0 +1,240 @@
/**
 * Copyright 2021-2024, XGBoost Contributors
 * \file linalg_op.h
 */
#ifndef PLUGIN_SYCL_COMMON_LINALG_OP_H_
#define PLUGIN_SYCL_COMMON_LINALG_OP_H_

#include <array>
#include <vector>
#include <utility>

#include "../data.h"

#include <CL/sycl.hpp>

namespace xgboost {
namespace sycl {
namespace linalg {

struct WorkGroupsParams {
  size_t n_workgroups;
  size_t workgroup_size;
};

template <typename Fn>
::sycl::event GroupWiseKernel(::sycl::queue* qu, int* flag_ptr,
                              const std::vector<::sycl::event>& events,
                              const WorkGroupsParams& wg, Fn &&fn) {
  ::sycl::buffer<int, 1> flag_buf(flag_ptr, 1);
  auto event = qu->submit([&](::sycl::handler& cgh) {
    cgh.depends_on(events);
    auto flag = flag_buf.get_access<::sycl::access::mode::write>(cgh);
    cgh.parallel_for_work_group<>(::sycl::range<1>(wg.n_workgroups),
                                  ::sycl::range<1>(wg.workgroup_size),
                                  [=](::sycl::group<1> group) {
      group.parallel_for_work_item([&](::sycl::h_item<1> item) {
        const size_t idx = item.get_global_id()[0];
        fn(idx, flag);
      });
    });
  });
  return event;
}

struct Argument {
  template <typename T>
  operator T&&() const;
};

template <typename Fn, typename Is, typename = void>
struct ArgumentsPassedImpl
  : std::false_type {};

template <typename Fn, size_t ...Is>
struct ArgumentsPassedImpl<Fn, std::index_sequence<Is...>,
                           decltype(std::declval<Fn>()(((void)Is, Argument{})...), void())>
  : std::true_type {};

template <typename Fn, size_t N>
struct ArgumentsPassed : ArgumentsPassedImpl<Fn, std::make_index_sequence<N>> {};

template <typename OutputDType, typename InputDType,
          size_t BatchSize, size_t MaxNumInputs>
class BatchProcessingHelper {
 public:
  static constexpr size_t kBatchSize = BatchSize;
  using InputType = HostDeviceVector<InputDType>;
  using OutputType = HostDeviceVector<OutputDType>;

 private:
  template <size_t NumInput = 0>
  void Host2Buffers(InputDType* in_buffer_ptr, const InputType& input) {
    /*
     * Some inputs may have fewer than 1 sample per output symbol.
     */
    const size_t sub_sample_rate = ndata_ * sample_rates_[NumInput+1] / input.Size();
    const size_t n_samples = batch_size_ * sample_rates_[NumInput+1] / sub_sample_rate;

    const InputDType* in_host_ptr = input.HostPointer() +
                                    batch_begin_ * sample_rates_[NumInput+1] / sub_sample_rate;

    events_[NumInput] =
        qu_->memcpy(in_buffer_ptr, in_host_ptr, n_samples * sizeof(InputDType),
                    events_[MaxNumInputs - 2]);
  }

  template <size_t NumInput = 0, class... InputTypes>
  void Host2Buffers(InputDType* in_buffer_ptr, const InputType& input,
                    const InputTypes&... other_inputs) {
    // Make a copy for the first input in the list
    Host2Buffers<NumInput>(in_buffer_ptr, input);
    // Recursive call for the remaining inputs
    InputDType* next_input = in_buffer_.Data() + in_buff_offsets_[NumInput + 1];
    Host2Buffers<NumInput+1>(next_input, other_inputs...);
  }

  void Buffers2Host(OutputType* output) {
    const size_t n_samples = batch_size_ * sample_rates_[0];
    OutputDType* out_host_ptr = output->HostPointer() + batch_begin_ * sample_rates_[0];
    events_[MaxNumInputs - 1] =
        qu_->memcpy(out_host_ptr, out_buffer_.DataConst(), n_samples * sizeof(OutputDType),
                    events_[MaxNumInputs - 2]);
  }

  void Buffers2Host(InputType* output) {
    const size_t n_samples = batch_size_ * sample_rates_[1];
    InputDType* out_host_ptr = output->HostPointer() + batch_begin_ * sample_rates_[1];
    events_[MaxNumInputs - 1] =
        qu_->memcpy(out_host_ptr, in_buffer_.DataConst(), n_samples * sizeof(InputDType),
                    events_[MaxNumInputs - 2]);
  }

  template <size_t NumInputs = 1, typename Fn, class... InputTypes>
  void Call(Fn &&fn, const InputDType* input, const InputTypes*... other_inputs) {
    static_assert(NumInputs <= MaxNumInputs,
                  "Too many arguments in the passed function");
    /* The passed lambda may take fewer inputs than MaxNumInputs,
     * so only the required number of arguments is forwarded.
     */
    // 1 for events, 1 for batch_size, 1 for output
    if constexpr (ArgumentsPassed<Fn, NumInputs + 1 + 1 + 1>::value) {
      events_[MaxNumInputs - 2] = fn(events_, batch_size_,
                                     out_buffer_.Data(), input, other_inputs...);
    } else {
      const InputDType* next_input = in_buffer_.DataConst() +
                                     in_buff_offsets_[MaxNumInputs - 1 - NumInputs];
      Call<NumInputs+1>(std::forward<Fn>(fn), next_input, input, other_inputs...);
    }
  }

  template <size_t NumInputs = 1, typename Fn, class... InputTypes>
  void Call(Fn &&fn, InputDType* io, const InputDType* input, const InputTypes*... other_inputs) {
    static_assert(NumInputs <= MaxNumInputs,
                  "Too many arguments in the passed function");
    if constexpr (ArgumentsPassed<Fn, NumInputs + 1 + 1>::value) {
      events_[MaxNumInputs - 2] = fn(events_, batch_size_,
                                     io, input, other_inputs...);
    } else {
      const InputDType* next_input = in_buffer_.DataConst() +
                                     in_buff_offsets_[MaxNumInputs - NumInputs];
      Call<NumInputs+1>(std::forward<Fn>(fn), io, next_input, input, other_inputs...);
    }
  }

  template <size_t NumInputs = 1, typename Fn>
  void Call(Fn &&fn, InputDType* io) {
    static_assert(NumInputs <= MaxNumInputs,
                  "Too many arguments in the passed function");
    if constexpr (ArgumentsPassed<Fn, NumInputs + 1 + 1>::value) {
      events_[MaxNumInputs - 2] = fn(events_, batch_size_, io);
    } else {
      const InputDType* next_input = in_buffer_.DataConst() +
                                     in_buff_offsets_[MaxNumInputs - 1];
      Call<NumInputs+1>(std::forward<Fn>(fn), io, next_input);
    }
  }

 public:
  BatchProcessingHelper() = default;

  // The first element of sample_rate always corresponds to the output sample rate
  void InitBuffers(::sycl::queue* qu, const std::vector<int>& sample_rate) {
    assert(sample_rate.size() == MaxNumInputs + 1);
    sample_rates_ = sample_rate;
    qu_ = qu;
    events_.resize(MaxNumInputs + 2);
    out_buffer_.Resize(qu, kBatchSize * sample_rate.front());

    in_buff_offsets_[0] = 0;
    for (size_t i = 1; i < MaxNumInputs; ++i) {
      in_buff_offsets_[i] = in_buff_offsets_[i - 1] + kBatchSize * sample_rate[i];
    }
    const size_t in_buff_size = in_buff_offsets_.back() + kBatchSize * sample_rate.back();
    in_buffer_.Resize(qu, in_buff_size);
  }

  /*
   * Batch-wise processing on the SYCL device:
   * output = fn(inputs)
   */
  template <typename Fn, class... InputTypes>
  void Calculate(Fn &&fn, OutputType* output, const InputTypes&... inputs) {
    ndata_ = output->Size() / sample_rates_.front();
    const size_t nBatch = ndata_ / kBatchSize + (ndata_ % kBatchSize > 0);
    for (size_t batch = 0; batch < nBatch; ++batch) {
      batch_begin_ = batch * kBatchSize;
      batch_end_ = (batch == nBatch - 1) ? ndata_ : batch_begin_ + kBatchSize;
      batch_size_ = batch_end_ - batch_begin_;

      // Iteratively copy all inputs to device buffers
      Host2Buffers(in_buffer_.Data(), inputs...);
      // Pack buffers and call the function.
      // The input pointer is shifted to keep the same order of inputs after packing.
      Call(std::forward<Fn>(fn), in_buffer_.DataConst() + in_buff_offsets_.back());
      // Copy results back to the host
      Buffers2Host(output);
    }
  }

  /*
   * Batch-wise processing on the SYCL device:
   * input = fn(input, other_inputs)
   */
  template <typename Fn, class... InputTypes>
  void Calculate(Fn &&fn, InputType* input, const InputTypes&... other_inputs) {
    ndata_ = input->Size();
    const size_t nBatch = ndata_ / kBatchSize + (ndata_ % kBatchSize > 0);
    for (size_t batch = 0; batch < nBatch; ++batch) {
      batch_begin_ = batch * kBatchSize;
      batch_end_ = (batch == nBatch - 1) ? ndata_ : batch_begin_ + kBatchSize;
      batch_size_ = batch_end_ - batch_begin_;

      // Iteratively copy all inputs to device buffers.
      // Inputs are passed by const reference.
      Host2Buffers(in_buffer_.Data(), *(input), other_inputs...);
      // Pack buffers and call the function.
      // The input pointer is shifted to keep the same order of inputs after packing.
      Call(std::forward<Fn>(fn), in_buffer_.Data());
      // Copy results back to the host
      Buffers2Host(input);
    }
  }

 private:
  std::array<int, MaxNumInputs> in_buff_offsets_;
  std::vector<int> sample_rates_;
  size_t ndata_;
  size_t batch_begin_;
  size_t batch_end_;
  // not equal to kBatchSize for the last batch
  size_t batch_size_;
  ::sycl::queue* qu_;
  std::vector<::sycl::event> events_;
  USMVector<InputDType, MemoryType::on_device> in_buffer_;
  USMVector<OutputDType, MemoryType::on_device> out_buffer_;
};

}  // namespace linalg
}  // namespace sycl
}  // namespace xgboost
#endif  // PLUGIN_SYCL_COMMON_LINALG_OP_H_
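The Argument/ArgumentsPassed machinery above is an arity probe: BatchProcessingHelper uses it to detect how many arguments the user-supplied lambda accepts, and therefore how many packed buffer pointers to forward. A minimal self-contained sketch of the same C++17 idiom, with hypothetical names (Probe, CallableWith) that are not part of the commit:

#include <cstddef>
#include <type_traits>
#include <utility>

// Probe converts to any type, so Fn(Probe, Probe, ...) is well-formed
// exactly when Fn can be called with that many arguments.
struct Probe {
  template <typename T>
  operator T&&() const;
};

template <typename Fn, typename Is, typename = void>
struct CallableWithImpl : std::false_type {};

template <typename Fn, std::size_t... Is>
struct CallableWithImpl<Fn, std::index_sequence<Is...>,
                        decltype(std::declval<Fn>()(((void)Is, Probe{})...), void())>
    : std::true_type {};

template <typename Fn, std::size_t N>
struct CallableWith : CallableWithImpl<Fn, std::make_index_sequence<N>> {};

int main() {
  auto fn = [](int, float) {};
  static_assert(CallableWith<decltype(fn), 2>::value, "callable with two arguments");
  static_assert(!CallableWith<decltype(fn), 3>::value, "not callable with three arguments");
  return 0;
}

This is the compile-time check that Call<NumInputs>() above uses to decide whether to invoke the lambda or to prepend one more buffer pointer and recurse.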
@@ -22,7 +22,10 @@
 #include "../../../src/objective/multiclass_param.h"
 
+#include "../common/linalg_op.h"
+
 #include "../device_manager.h"
+#include "../data.h"
 
 #include <CL/sycl.hpp>
 
 namespace xgboost {
@@ -32,6 +35,15 @@ namespace obj {
 DMLC_REGISTRY_FILE_TAG(multiclass_obj_sycl);
 
 class SoftmaxMultiClassObj : public ObjFunction {
+  mutable bool are_buffs_init = false;
+
+  void InitBuffers(const std::vector<int>& sample_rate) const {
+    if (!are_buffs_init) {
+      batch_processor_.InitBuffers(&qu_, sample_rate);
+      are_buffs_init = true;
+    }
+  }
+
  public:
   explicit SoftmaxMultiClassObj(bool output_prob)
       : output_prob_(output_prob) {}
@@ -44,7 +56,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
   void GetGradient(const HostDeviceVector<bst_float>& preds,
                    const MetaInfo& info,
                    int iter,
-                   linalg::Matrix<GradientPair>* out_gpair) override {
+                   xgboost::linalg::Matrix<GradientPair>* out_gpair) override {
     if (preds.Size() == 0) return;
     if (info.labels.Size() == 0) return;
 
@@ -66,54 +78,68 @@ class SoftmaxMultiClassObj : public ObjFunction {
           << "Number of weights should be equal to number of data points.";
     }
 
-    ::sycl::buffer<bst_float, 1> preds_buf(preds.HostPointer(), preds.Size());
-    ::sycl::buffer<bst_float, 1> labels_buf(info.labels.Data()->HostPointer(), info.labels.Size());
-    ::sycl::buffer<GradientPair, 1> out_gpair_buf(out_gpair->Data()->HostPointer(),
-                                                  out_gpair->Size());
-    ::sycl::buffer<bst_float, 1> weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(),
-                                             is_null_weight ? 1 : info.weights_.Size());
-
     int flag = 1;
-    {
-      ::sycl::buffer<int, 1> flag_buf(&flag, 1);
-      qu_.submit([&](::sycl::handler& cgh) {
-        auto preds_acc = preds_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto labels_acc = labels_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto weights_acc = weights_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto out_gpair_acc = out_gpair_buf.get_access<::sycl::access::mode::write>(cgh);
-        auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);
-        cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
-          int idx = pid[0];
-          bst_float const * point = &preds_acc[idx * nclass];
+    auto objective_fn = [=, &flag]
+                        (const std::vector<::sycl::event>& events,
+                         size_t ndata,
+                         GradientPair* out_gpair,
+                         const bst_float* preds,
+                         const bst_float* labels,
+                         const bst_float* weights) {
+      const size_t wg_size = 32;
+      const size_t nwgs = ndata / wg_size + (ndata % wg_size > 0);
+      return linalg::GroupWiseKernel(&qu_, &flag, events, {nwgs, wg_size},
+        [=] (size_t idx, auto flag) {
+          const bst_float* pred = preds + idx * nclass;
 
           // Part of Softmax function
           bst_float wmax = std::numeric_limits<bst_float>::min();
-          for (int k = 0; k < nclass; k++) { wmax = ::sycl::max(point[k], wmax); }
-          float wsum = 0.0f;
-          for (int k = 0; k < nclass; k++) { wsum += ::sycl::exp(point[k] - wmax); }
-          auto label = labels_acc[idx];
+          for (int k = 0; k < nclass; k++) { wmax = ::sycl::max(pred[k], wmax); }
+          bst_float wsum = 0.0f;
+          for (int k = 0; k < nclass; k++) { wsum += ::sycl::exp(pred[k] - wmax); }
+          bst_float label = labels[idx];
 
           if (label < 0 || label >= nclass) {
-            flag_buf_acc[0] = 0;
+            AtomicRef<int> flag_ref(flag[0]);
+            flag_ref = 0;
             label = 0;
           }
-          bst_float wt = is_null_weight ? 1.0f : weights_acc[idx];
+
+          bst_float wt = is_null_weight ? 1.0f : weights[idx];
           for (int k = 0; k < nclass; ++k) {
-            bst_float p = expf(point[k] - wmax) / static_cast<float>(wsum);
+            bst_float p = expf(pred[k] - wmax) / static_cast<float>(wsum);
             const float eps = 1e-16f;
             const bst_float h = ::sycl::max(2.0f * p * (1.0f - p) * wt, eps);
             p = label == k ? p - 1.0f : p;
-            out_gpair_acc[idx * nclass + k] = GradientPair(p * wt, h);
+            out_gpair[idx * nclass + k] = GradientPair(p * wt, h);
           }
         });
-      }).wait();
+    };
+
+    // out_gpair and preds have nclass points per sample;
+    // labels and weights have 1 point per sample
+    InitBuffers({nclass, nclass, 1, 1});
+    if (is_null_weight) {
+      // Output is passed by pointer;
+      // inputs are passed by const reference
+      batch_processor_.Calculate(std::move(objective_fn),
+                                 out_gpair->Data(),
+                                 preds,
+                                 *(info.labels.Data()));
+    } else {
+      batch_processor_.Calculate(std::move(objective_fn),
+                                 out_gpair->Data(),
+                                 preds,
+                                 *(info.labels.Data()),
+                                 info.weights_);
     }
-    // flag_buf is destroyed, content is copyed to the "flag"
+    qu_.wait_and_throw();
 
     if (flag == 0) {
       LOG(FATAL) << "SYCL::SoftmaxMultiClassObj: label must be in [0, num_class).";
     }
   }
 
   void PredTransform(HostDeviceVector<bst_float>* io_preds) const override {
     this->Transform(io_preds, output_prob_);
   }
@@ -190,6 +216,8 @@ class SoftmaxMultiClassObj : public ObjFunction {
   sycl::DeviceManager device_manager;
 
   mutable ::sycl::queue qu_;
+  static constexpr size_t kBatchSize = 1u << 22;
+  mutable linalg::BatchProcessingHelper<GradientPair, bst_float, kBatchSize, 3> batch_processor_;
 };
 
 XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax_sycl")
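Both objectives follow the same wiring: build a lambda that launches GroupWiseKernel, hand it to batch_processor_.Calculate() together with the output vector and the input vectors, then wait on the queue and check the error flag. A minimal sketch of that pattern outside an objective; the function name, include paths, and the squared-error style gradient are assumptions, not part of the commit:

#include <utility>
#include <vector>

#include <CL/sycl.hpp>

#include "xgboost/base.h"                 // GradientPair, bst_float
#include "xgboost/host_device_vector.h"
#include "../common/linalg_op.h"          // assumed relative include path

namespace xgb_sycl = xgboost::sycl;

void ExampleGradient(::sycl::queue* qu,
                     xgboost::HostDeviceVector<xgboost::GradientPair>* out_gpair,
                     const xgboost::HostDeviceVector<xgboost::bst_float>& preds,
                     const xgboost::HostDeviceVector<xgboost::bst_float>& labels,
                     const xgboost::HostDeviceVector<xgboost::bst_float>& weights) {
  using xgboost::bst_float;
  using xgboost::GradientPair;

  static constexpr size_t kBatchSize = 1u << 22;
  xgb_sycl::linalg::BatchProcessingHelper<GradientPair, bst_float, kBatchSize, 3> helper;
  // The first rate is for the output, the rest for the three inputs:
  // here every stream holds one value per sample.
  helper.InitBuffers(qu, {1, 1, 1, 1});

  int flag = 1;
  auto fn = [=, &flag](const std::vector<::sycl::event>& events, size_t ndata,
                       GradientPair* out, const bst_float* pred,
                       const bst_float* label, const bst_float* weight) {
    const size_t wg_size = 32;
    const size_t nwgs = ndata / wg_size + (ndata % wg_size > 0);
    return xgb_sycl::linalg::GroupWiseKernel(qu, &flag, events, {nwgs, wg_size},
        [=](size_t idx, auto flag_acc) {
          // Hypothetical squared-error style gradient; the real objectives
          // use Loss::FirstOrderGradient / SecondOrderGradient or the softmax code.
          const bst_float diff = pred[idx] - label[idx];
          out[idx] = GradientPair(diff * weight[idx], weight[idx]);
        });
  };

  // Inputs are copied batch-wise into device buffers, fn runs on each batch,
  // and the packed result is copied back into out_gpair.
  helper.Calculate(std::move(fn), out_gpair, preds, labels, weights);
  qu->wait_and_throw();
  // flag would be 0 here if the kernel had reported invalid data via AtomicRef.
}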
@@ -27,7 +27,10 @@
 #pragma GCC diagnostic pop
 #include "../../../src/objective/regression_param.h"
 
+#include "../common/linalg_op.h"
+
 #include "../device_manager.h"
+#include "../data.h"
 
 #include <CL/sycl.hpp>
 
@@ -41,6 +44,14 @@ template<typename Loss>
 class RegLossObj : public ObjFunction {
  protected:
   HostDeviceVector<int> label_correct_;
+  mutable bool are_buffs_init = false;
+
+  void InitBuffers() const {
+    if (!are_buffs_init) {
+      batch_processor_.InitBuffers(&qu_, {1, 1, 1, 1});
+      are_buffs_init = true;
+    }
+  }
 
  public:
   RegLossObj() = default;
@@ -53,7 +64,7 @@ class RegLossObj : public ObjFunction {
   void GetGradient(const HostDeviceVector<bst_float>& preds,
                    const MetaInfo &info,
                    int iter,
-                   linalg::Matrix<GradientPair>* out_gpair) override {
+                   xgboost::linalg::Matrix<GradientPair>* out_gpair) override {
     if (info.labels.Size() == 0) return;
     CHECK_EQ(preds.Size(), info.labels.Size())
         << " " << "labels are not correctly provided"
@@ -70,13 +81,6 @@ class RegLossObj : public ObjFunction {
 
     bool is_null_weight = info.weights_.Size() == 0;
-
-    ::sycl::buffer<bst_float, 1> preds_buf(preds.HostPointer(), preds.Size());
-    ::sycl::buffer<bst_float, 1> labels_buf(info.labels.Data()->HostPointer(), info.labels.Size());
-    ::sycl::buffer<GradientPair, 1> out_gpair_buf(out_gpair->Data()->HostPointer(),
-                                                  out_gpair->Size());
-    ::sycl::buffer<bst_float, 1> weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(),
-                                             is_null_weight ? 1 : info.weights_.Size());
 
     auto scale_pos_weight = param_.scale_pos_weight;
     if (!is_null_weight) {
       CHECK_EQ(info.weights_.Size(), info.labels.Shape(0))
@@ -84,32 +88,48 @@ class RegLossObj : public ObjFunction {
     }
 
     int flag = 1;
-    {
-      ::sycl::buffer<int, 1> flag_buf(&flag, 1);
-      qu_.submit([&](::sycl::handler& cgh) {
-        auto preds_acc = preds_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto labels_acc = labels_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto weights_acc = weights_buf.get_access<::sycl::access::mode::read>(cgh);
-        auto out_gpair_acc = out_gpair_buf.get_access<::sycl::access::mode::write>(cgh);
-        auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);
-        cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
-          int idx = pid[0];
-          bst_float p = Loss::PredTransform(preds_acc[idx]);
-          bst_float w = is_null_weight ? 1.0f : weights_acc[idx/n_targets];
-          bst_float label = labels_acc[idx];
+    auto objective_fn = [=, &flag]
+                        (const std::vector<::sycl::event>& events,
+                         size_t ndata,
+                         GradientPair* out_gpair,
+                         const bst_float* preds,
+                         const bst_float* labels,
+                         const bst_float* weights) {
+      const size_t wg_size = 32;
+      const size_t nwgs = ndata / wg_size + (ndata % wg_size > 0);
+      return linalg::GroupWiseKernel(&qu_, &flag, events, {nwgs, wg_size},
+        [=] (size_t idx, auto flag) {
+          const bst_float pred = Loss::PredTransform(preds[idx]);
+          bst_float weight = is_null_weight ? 1.0f : weights[idx/n_targets];
+          const bst_float label = labels[idx];
           if (label == 1.0f) {
-            w *= scale_pos_weight;
+            weight *= scale_pos_weight;
           }
           if (!Loss::CheckLabel(label)) {
-            // If there is an incorrect label, the host code will know.
-            flag_buf_acc[0] = 0;
+            AtomicRef<int> flag_ref(flag[0]);
+            flag_ref = 0;
           }
-          out_gpair_acc[idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w,
-                                            Loss::SecondOrderGradient(p, label) * w);
+          out_gpair[idx] = GradientPair(Loss::FirstOrderGradient(pred, label) * weight,
+                                        Loss::SecondOrderGradient(pred, label) * weight);
         });
-      }).wait();
+    };
+
+    InitBuffers();
+    if (is_null_weight) {
+      // Output is passed by pointer;
+      // inputs are passed by const reference
+      batch_processor_.Calculate(std::move(objective_fn),
+                                 out_gpair->Data(),
+                                 preds,
+                                 *(info.labels.Data()));
+    } else {
+      batch_processor_.Calculate(std::move(objective_fn),
+                                 out_gpair->Data(),
+                                 preds,
+                                 *(info.labels.Data()),
+                                 info.weights_);
     }
-    // flag_buf is destroyed, content is copyed to the "flag"
+    qu_.wait_and_throw();
 
     if (flag == 0) {
       LOG(FATAL) << Loss::LabelErrorMsg();
@@ -121,18 +141,23 @@ class RegLossObj : public ObjFunction {
     return Loss::DefaultEvalMetric();
   }
 
-  void PredTransform(HostDeviceVector<float> *io_preds) const override {
+  void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
     size_t const ndata = io_preds->Size();
     if (ndata == 0) return;
-    ::sycl::buffer<bst_float, 1> io_preds_buf(io_preds->HostPointer(), io_preds->Size());
+    InitBuffers();
 
-    qu_.submit([&](::sycl::handler& cgh) {
-      auto io_preds_acc = io_preds_buf.get_access<::sycl::access::mode::read_write>(cgh);
+    batch_processor_.Calculate([=] (const std::vector<::sycl::event>& events,
+                                    size_t ndata,
+                                    bst_float* io_preds) {
+      return qu_.submit([&](::sycl::handler& cgh) {
+        cgh.depends_on(events);
         cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
           int idx = pid[0];
-          io_preds_acc[idx] = Loss::PredTransform(io_preds_acc[idx]);
+          io_preds[idx] = Loss::PredTransform(io_preds[idx]);
         });
-    }).wait();
+      });
+    }, io_preds);
+    qu_.wait_and_throw();
   }
 
   float ProbToMargin(float base_score) const override {
@@ -163,6 +188,8 @@ class RegLossObj : public ObjFunction {
   sycl::DeviceManager device_manager;
 
   mutable ::sycl::queue qu_;
+  static constexpr size_t kBatchSize = 1u << 22;
+  mutable linalg::BatchProcessingHelper<GradientPair, bst_float, kBatchSize, 3> batch_processor_;
 };
 
 XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegression,
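With kBatchSize = 1u << 22, each device pass handles at most about 4.2 million samples, and the number of passes follows the same formula Calculate() uses. A small illustration with a hypothetical dataset size (not from the commit):

#include <cstddef>
#include <cstdio>

int main() {
  constexpr std::size_t kBatchSize = 1u << 22;   // 4 194 304 samples per batch
  const std::size_t ndata = 10'000'000;          // hypothetical number of rows
  const std::size_t n_batch = ndata / kBatchSize + (ndata % kBatchSize > 0);
  const std::size_t last_batch_size = ndata - (n_batch - 1) * kBatchSize;
  std::printf("%zu batches, last batch has %zu samples\n", n_batch, last_batch_size);
  // Prints: 3 batches, last batch has 1611392 samples
  return 0;
}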