[SYCL]. Implementation of HostDeviceVector (#10842)

This commit is contained in:
Dmitry Razdoburdin
2024-09-24 22:45:17 +02:00
committed by GitHub
parent bc69a3e877
commit 2179baa50c
25 changed files with 937 additions and 282 deletions

View File

@@ -39,7 +39,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
void InitBuffers(const std::vector<int>& sample_rate) const {
if (!are_buffs_init) {
batch_processor_.InitBuffers(&qu_, sample_rate);
batch_processor_.InitBuffers(qu_, sample_rate);
are_buffs_init = true;
}
}
@@ -88,7 +88,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
const bst_float* weights) {
const size_t wg_size = 32;
const size_t nwgs = ndata / wg_size + (ndata % wg_size > 0);
return linalg::GroupWiseKernel(&qu_, &flag, events, {nwgs, wg_size},
return linalg::GroupWiseKernel(qu_, &flag, events, {nwgs, wg_size},
[=] (size_t idx, auto flag) {
const bst_float* pred = preds + idx * nclass;
@@ -133,7 +133,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
*(info.labels.Data()),
info.weights_);
}
qu_.wait_and_throw();
qu_->wait_and_throw();
if (flag == 0) {
LOG(FATAL) << "SYCL::SoftmaxMultiClassObj: label must be in [0, num_class).";
@@ -160,7 +160,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
::sycl::buffer<bst_float, 1> io_preds_buf(io_preds->HostPointer(), io_preds->Size());
if (prob) {
qu_.submit([&](::sycl::handler& cgh) {
qu_->submit([&](::sycl::handler& cgh) {
auto io_preds_acc = io_preds_buf.get_access<::sycl::access::mode::read_write>(cgh);
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
int idx = pid[0];
@@ -171,7 +171,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
} else {
::sycl::buffer<bst_float, 1> max_preds_buf(max_preds_.HostPointer(), max_preds_.Size());
qu_.submit([&](::sycl::handler& cgh) {
qu_->submit([&](::sycl::handler& cgh) {
auto io_preds_acc = io_preds_buf.get_access<::sycl::access::mode::read>(cgh);
auto max_preds_acc = max_preds_buf.get_access<::sycl::access::mode::read_write>(cgh);
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
@@ -215,7 +215,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
sycl::DeviceManager device_manager;
mutable ::sycl::queue qu_;
mutable ::sycl::queue* qu_;
static constexpr size_t kBatchSize = 1u << 22;
mutable linalg::BatchProcessingHelper<GradientPair, bst_float, kBatchSize, 3> batch_processor_;
};

View File

@@ -48,7 +48,7 @@ class RegLossObj : public ObjFunction {
void InitBuffers() const {
if (!are_buffs_init) {
batch_processor_.InitBuffers(&qu_, {1, 1, 1, 1});
batch_processor_.InitBuffers(qu_, {1, 1, 1, 1});
are_buffs_init = true;
}
}
@@ -58,13 +58,16 @@ class RegLossObj : public ObjFunction {
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.UpdateAllowUnknown(args);
qu_ = device_manager.GetQueue(ctx_->Device());
}
void GetGradient(const HostDeviceVector<bst_float>& preds,
const MetaInfo &info,
int iter,
xgboost::linalg::Matrix<GradientPair>* out_gpair) override {
if (qu_ == nullptr) {
LOG(WARNING) << ctx_->Device();
qu_ = device_manager.GetQueue(ctx_->Device());
}
if (info.labels.Size() == 0) return;
CHECK_EQ(preds.Size(), info.labels.Size())
<< " " << "labels are not correctly provided"
@@ -97,7 +100,7 @@ class RegLossObj : public ObjFunction {
const bst_float* weights) {
const size_t wg_size = 32;
const size_t nwgs = ndata / wg_size + (ndata % wg_size > 0);
return linalg::GroupWiseKernel(&qu_, &flag, events, {nwgs, wg_size},
return linalg::GroupWiseKernel(qu_, &flag, events, {nwgs, wg_size},
[=] (size_t idx, auto flag) {
const bst_float pred = Loss::PredTransform(preds[idx]);
bst_float weight = is_null_weight ? 1.0f : weights[idx/n_targets];
@@ -129,7 +132,7 @@ class RegLossObj : public ObjFunction {
*(info.labels.Data()),
info.weights_);
}
qu_.wait_and_throw();
qu_->wait_and_throw();
if (flag == 0) {
LOG(FATAL) << Loss::LabelErrorMsg();
@@ -142,6 +145,10 @@ class RegLossObj : public ObjFunction {
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
if (qu_ == nullptr) {
LOG(WARNING) << ctx_->Device();
qu_ = device_manager.GetQueue(ctx_->Device());
}
size_t const ndata = io_preds->Size();
if (ndata == 0) return;
InitBuffers();
@@ -149,7 +156,7 @@ class RegLossObj : public ObjFunction {
batch_processor_.Calculate([=] (const std::vector<::sycl::event>& events,
size_t ndata,
bst_float* io_preds) {
return qu_.submit([&](::sycl::handler& cgh) {
return qu_->submit([&](::sycl::handler& cgh) {
cgh.depends_on(events);
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
int idx = pid[0];
@@ -157,7 +164,7 @@ class RegLossObj : public ObjFunction {
});
});
}, io_preds);
qu_.wait_and_throw();
qu_->wait_and_throw();
}
float ProbToMargin(float base_score) const override {
@@ -187,7 +194,7 @@ class RegLossObj : public ObjFunction {
xgboost::obj::RegLossParam param_;
sycl::DeviceManager device_manager;
mutable ::sycl::queue qu_;
mutable ::sycl::queue* qu_ = nullptr;
static constexpr size_t kBatchSize = 1u << 22;
mutable linalg::BatchProcessingHelper<GradientPair, bst_float, kBatchSize, 3> batch_processor_;
};