[SYC]. Implementation of HostDeviceVector (#10842)
This commit is contained in:
committed by
GitHub
parent
bc69a3e877
commit
2179baa50c
@@ -39,7 +39,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
||||
|
||||
void InitBuffers(const std::vector<int>& sample_rate) const {
|
||||
if (!are_buffs_init) {
|
||||
batch_processor_.InitBuffers(&qu_, sample_rate);
|
||||
batch_processor_.InitBuffers(qu_, sample_rate);
|
||||
are_buffs_init = true;
|
||||
}
|
||||
}
|
||||
@@ -88,7 +88,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
||||
const bst_float* weights) {
|
||||
const size_t wg_size = 32;
|
||||
const size_t nwgs = ndata / wg_size + (ndata % wg_size > 0);
|
||||
return linalg::GroupWiseKernel(&qu_, &flag, events, {nwgs, wg_size},
|
||||
return linalg::GroupWiseKernel(qu_, &flag, events, {nwgs, wg_size},
|
||||
[=] (size_t idx, auto flag) {
|
||||
const bst_float* pred = preds + idx * nclass;
|
||||
|
||||
@@ -133,7 +133,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
||||
*(info.labels.Data()),
|
||||
info.weights_);
|
||||
}
|
||||
qu_.wait_and_throw();
|
||||
qu_->wait_and_throw();
|
||||
|
||||
if (flag == 0) {
|
||||
LOG(FATAL) << "SYCL::SoftmaxMultiClassObj: label must be in [0, num_class).";
|
||||
@@ -160,7 +160,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
||||
::sycl::buffer<bst_float, 1> io_preds_buf(io_preds->HostPointer(), io_preds->Size());
|
||||
|
||||
if (prob) {
|
||||
qu_.submit([&](::sycl::handler& cgh) {
|
||||
qu_->submit([&](::sycl::handler& cgh) {
|
||||
auto io_preds_acc = io_preds_buf.get_access<::sycl::access::mode::read_write>(cgh);
|
||||
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
|
||||
int idx = pid[0];
|
||||
@@ -171,7 +171,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
||||
} else {
|
||||
::sycl::buffer<bst_float, 1> max_preds_buf(max_preds_.HostPointer(), max_preds_.Size());
|
||||
|
||||
qu_.submit([&](::sycl::handler& cgh) {
|
||||
qu_->submit([&](::sycl::handler& cgh) {
|
||||
auto io_preds_acc = io_preds_buf.get_access<::sycl::access::mode::read>(cgh);
|
||||
auto max_preds_acc = max_preds_buf.get_access<::sycl::access::mode::read_write>(cgh);
|
||||
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
|
||||
@@ -215,7 +215,7 @@ class SoftmaxMultiClassObj : public ObjFunction {
|
||||
|
||||
sycl::DeviceManager device_manager;
|
||||
|
||||
mutable ::sycl::queue qu_;
|
||||
mutable ::sycl::queue* qu_;
|
||||
static constexpr size_t kBatchSize = 1u << 22;
|
||||
mutable linalg::BatchProcessingHelper<GradientPair, bst_float, kBatchSize, 3> batch_processor_;
|
||||
};
|
||||
|
||||
@@ -48,7 +48,7 @@ class RegLossObj : public ObjFunction {
|
||||
|
||||
void InitBuffers() const {
|
||||
if (!are_buffs_init) {
|
||||
batch_processor_.InitBuffers(&qu_, {1, 1, 1, 1});
|
||||
batch_processor_.InitBuffers(qu_, {1, 1, 1, 1});
|
||||
are_buffs_init = true;
|
||||
}
|
||||
}
|
||||
@@ -58,13 +58,16 @@ class RegLossObj : public ObjFunction {
|
||||
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
qu_ = device_manager.GetQueue(ctx_->Device());
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
xgboost::linalg::Matrix<GradientPair>* out_gpair) override {
|
||||
if (qu_ == nullptr) {
|
||||
LOG(WARNING) << ctx_->Device();
|
||||
qu_ = device_manager.GetQueue(ctx_->Device());
|
||||
}
|
||||
if (info.labels.Size() == 0) return;
|
||||
CHECK_EQ(preds.Size(), info.labels.Size())
|
||||
<< " " << "labels are not correctly provided"
|
||||
@@ -97,7 +100,7 @@ class RegLossObj : public ObjFunction {
|
||||
const bst_float* weights) {
|
||||
const size_t wg_size = 32;
|
||||
const size_t nwgs = ndata / wg_size + (ndata % wg_size > 0);
|
||||
return linalg::GroupWiseKernel(&qu_, &flag, events, {nwgs, wg_size},
|
||||
return linalg::GroupWiseKernel(qu_, &flag, events, {nwgs, wg_size},
|
||||
[=] (size_t idx, auto flag) {
|
||||
const bst_float pred = Loss::PredTransform(preds[idx]);
|
||||
bst_float weight = is_null_weight ? 1.0f : weights[idx/n_targets];
|
||||
@@ -129,7 +132,7 @@ class RegLossObj : public ObjFunction {
|
||||
*(info.labels.Data()),
|
||||
info.weights_);
|
||||
}
|
||||
qu_.wait_and_throw();
|
||||
qu_->wait_and_throw();
|
||||
|
||||
if (flag == 0) {
|
||||
LOG(FATAL) << Loss::LabelErrorMsg();
|
||||
@@ -142,6 +145,10 @@ class RegLossObj : public ObjFunction {
|
||||
}
|
||||
|
||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) const override {
|
||||
if (qu_ == nullptr) {
|
||||
LOG(WARNING) << ctx_->Device();
|
||||
qu_ = device_manager.GetQueue(ctx_->Device());
|
||||
}
|
||||
size_t const ndata = io_preds->Size();
|
||||
if (ndata == 0) return;
|
||||
InitBuffers();
|
||||
@@ -149,7 +156,7 @@ class RegLossObj : public ObjFunction {
|
||||
batch_processor_.Calculate([=] (const std::vector<::sycl::event>& events,
|
||||
size_t ndata,
|
||||
bst_float* io_preds) {
|
||||
return qu_.submit([&](::sycl::handler& cgh) {
|
||||
return qu_->submit([&](::sycl::handler& cgh) {
|
||||
cgh.depends_on(events);
|
||||
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
|
||||
int idx = pid[0];
|
||||
@@ -157,7 +164,7 @@ class RegLossObj : public ObjFunction {
|
||||
});
|
||||
});
|
||||
}, io_preds);
|
||||
qu_.wait_and_throw();
|
||||
qu_->wait_and_throw();
|
||||
}
|
||||
|
||||
float ProbToMargin(float base_score) const override {
|
||||
@@ -187,7 +194,7 @@ class RegLossObj : public ObjFunction {
|
||||
xgboost::obj::RegLossParam param_;
|
||||
sycl::DeviceManager device_manager;
|
||||
|
||||
mutable ::sycl::queue qu_;
|
||||
mutable ::sycl::queue* qu_ = nullptr;
|
||||
static constexpr size_t kBatchSize = 1u << 22;
|
||||
mutable linalg::BatchProcessingHelper<GradientPair, bst_float, kBatchSize, 3> batch_processor_;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user