merge latest, Jan 12 2024
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
if(PLUGIN_SYCL)
|
||||
set(CMAKE_CXX_COMPILER "icpx")
|
||||
add_library(plugin_sycl OBJECT
|
||||
${xgboost_SOURCE_DIR}/plugin/sycl/objective/regression_obj.cc
|
||||
${xgboost_SOURCE_DIR}/plugin/sycl/objective/multiclass_obj.cc
|
||||
${xgboost_SOURCE_DIR}/plugin/sycl/device_manager.cc
|
||||
${xgboost_SOURCE_DIR}/plugin/sycl/predictor/predictor.cc)
|
||||
target_include_directories(plugin_sycl
|
||||
|
||||
@@ -1,7 +1,16 @@
|
||||
# gRPC needs to be installed first. See README.md.
|
||||
set(protobuf_MODULE_COMPATIBLE TRUE)
|
||||
set(protobuf_BUILD_SHARED_LIBS TRUE)
|
||||
find_package(Protobuf CONFIG REQUIRED)
|
||||
|
||||
find_package(Protobuf CONFIG)
|
||||
if(NOT Protobuf_FOUND)
|
||||
find_package(Protobuf)
|
||||
endif()
|
||||
if(NOT Protobuf_FOUND)
|
||||
# let CMake emit error
|
||||
find_package(Protobuf CONFIG REQUIRED)
|
||||
endif()
|
||||
|
||||
find_package(gRPC CONFIG REQUIRED)
|
||||
message(STATUS "Found gRPC: ${gRPC_CONFIG}")
|
||||
|
||||
|
||||
@@ -66,13 +66,13 @@ class USMVector {
|
||||
public:
|
||||
USMVector() : size_(0), capacity_(0), data_(nullptr) {}
|
||||
|
||||
USMVector(::sycl::queue& qu, size_t size) : size_(size), capacity_(size) {
|
||||
USMVector(::sycl::queue* qu, size_t size) : size_(size), capacity_(size) {
|
||||
data_ = allocate_memory_(qu, size_);
|
||||
}
|
||||
|
||||
USMVector(::sycl::queue& qu, size_t size, T v) : size_(size), capacity_(size) {
|
||||
USMVector(::sycl::queue* qu, size_t size, T v) : size_(size), capacity_(size) {
|
||||
data_ = allocate_memory_(qu, size_);
|
||||
qu.fill(data_.get(), v, size_).wait();
|
||||
qu->fill(data_.get(), v, size_).wait();
|
||||
}
|
||||
|
||||
USMVector(::sycl::queue* qu, const std::vector<T> &vec) {
|
||||
@@ -147,25 +147,22 @@ class USMVector {
|
||||
}
|
||||
}
|
||||
|
||||
::sycl::event ResizeAsync(::sycl::queue* qu, size_t size_new, T v) {
|
||||
void Resize(::sycl::queue* qu, size_t size_new, T v, ::sycl::event* event) {
|
||||
if (size_new <= size_) {
|
||||
size_ = size_new;
|
||||
return ::sycl::event();
|
||||
} else if (size_new <= capacity_) {
|
||||
auto event = qu->fill(data_.get() + size_, v, size_new - size_);
|
||||
size_ = size_new;
|
||||
return event;
|
||||
} else {
|
||||
size_t size_old = size_;
|
||||
auto data_old = data_;
|
||||
size_ = size_new;
|
||||
capacity_ = size_new;
|
||||
data_ = allocate_memory_(qu, size_);
|
||||
::sycl::event event;
|
||||
if (size_old > 0) {
|
||||
event = qu->memcpy(data_.get(), data_old.get(), sizeof(T) * size_old);
|
||||
*event = qu->memcpy(data_.get(), data_old.get(), sizeof(T) * size_old, *event);
|
||||
}
|
||||
return qu->fill(data_.get() + size_old, v, size_new - size_old, event);
|
||||
*event = qu->fill(data_.get() + size_old, v, size_new - size_old, *event);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -210,7 +207,7 @@ struct DeviceMatrix {
|
||||
DMatrix* p_mat; // Pointer to the original matrix on the host
|
||||
::sycl::queue qu_;
|
||||
USMVector<size_t> row_ptr;
|
||||
USMVector<Entry> data;
|
||||
USMVector<Entry, MemoryType::on_device> data;
|
||||
size_t total_offset;
|
||||
|
||||
DeviceMatrix(::sycl::queue qu, DMatrix* dmat) : p_mat(dmat), qu_(qu) {
|
||||
@@ -238,8 +235,9 @@ struct DeviceMatrix {
|
||||
for (size_t i = 0; i < batch_size; i++)
|
||||
row_ptr[i + batch.base_rowid] += batch.base_rowid;
|
||||
}
|
||||
std::copy(data_vec.data(), data_vec.data() + offset_vec[batch_size],
|
||||
data.Data() + data_offset);
|
||||
qu.memcpy(data.Data() + data_offset,
|
||||
data_vec.data(),
|
||||
offset_vec[batch_size] * sizeof(Entry)).wait();
|
||||
data_offset += offset_vec[batch_size];
|
||||
}
|
||||
}
|
||||
|
||||
210
plugin/sycl/objective/multiclass_obj.cc
Normal file
210
plugin/sycl/objective/multiclass_obj.cc
Normal file
@@ -0,0 +1,210 @@
|
||||
/*!
|
||||
* Copyright 2015-2023 by Contributors
|
||||
* \file multiclass_obj.cc
|
||||
* \brief Definition of multi-class classification objectives.
|
||||
*/
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||
#pragma GCC diagnostic ignored "-W#pragma-messages"
|
||||
#include <rabit/rabit.h>
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
|
||||
#include "xgboost/parameter.h"
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||
#include "xgboost/data.h"
|
||||
#include "../../src/common/math.h"
|
||||
#pragma GCC diagnostic pop
|
||||
#include "xgboost/logging.h"
|
||||
#include "xgboost/objective.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/span.h"
|
||||
|
||||
#include "../../../src/objective/multiclass_param.h"
|
||||
|
||||
#include "../device_manager.h"
|
||||
#include <CL/sycl.hpp>
|
||||
|
||||
namespace xgboost {
|
||||
namespace sycl {
|
||||
namespace obj {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(multiclass_obj_sycl);
|
||||
|
||||
class SoftmaxMultiClassObj : public ObjFunction {
|
||||
public:
|
||||
explicit SoftmaxMultiClassObj(bool output_prob)
|
||||
: output_prob_(output_prob) {}
|
||||
|
||||
void Configure(Args const& args) override {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
qu_ = device_manager.GetQueue(ctx_->Device());
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo& info,
|
||||
int iter,
|
||||
linalg::Matrix<GradientPair>* out_gpair) override {
|
||||
if (preds.Size() == 0) return;
|
||||
if (info.labels.Size() == 0) return;
|
||||
|
||||
CHECK(preds.Size() == (static_cast<size_t>(param_.num_class) * info.labels.Size()))
|
||||
<< "SoftmaxMultiClassObj: label size and pred size does not match.\n"
|
||||
<< "label.Size() * num_class: "
|
||||
<< info.labels.Size() * static_cast<size_t>(param_.num_class) << "\n"
|
||||
<< "num_class: " << param_.num_class << "\n"
|
||||
<< "preds.Size(): " << preds.Size();
|
||||
|
||||
const int nclass = param_.num_class;
|
||||
const auto ndata = static_cast<int64_t>(preds.Size() / nclass);
|
||||
|
||||
out_gpair->Reshape(info.num_row_, static_cast<std::uint64_t>(nclass));
|
||||
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
if (!is_null_weight) {
|
||||
CHECK_EQ(info.weights_.Size(), ndata)
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
|
||||
::sycl::buffer<bst_float, 1> preds_buf(preds.HostPointer(), preds.Size());
|
||||
::sycl::buffer<bst_float, 1> labels_buf(info.labels.Data()->HostPointer(), info.labels.Size());
|
||||
::sycl::buffer<GradientPair, 1> out_gpair_buf(out_gpair->Data()->HostPointer(),
|
||||
out_gpair->Size());
|
||||
::sycl::buffer<bst_float, 1> weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(),
|
||||
is_null_weight ? 1 : info.weights_.Size());
|
||||
|
||||
int flag = 1;
|
||||
{
|
||||
::sycl::buffer<int, 1> flag_buf(&flag, 1);
|
||||
qu_.submit([&](::sycl::handler& cgh) {
|
||||
auto preds_acc = preds_buf.get_access<::sycl::access::mode::read>(cgh);
|
||||
auto labels_acc = labels_buf.get_access<::sycl::access::mode::read>(cgh);
|
||||
auto weights_acc = weights_buf.get_access<::sycl::access::mode::read>(cgh);
|
||||
auto out_gpair_acc = out_gpair_buf.get_access<::sycl::access::mode::write>(cgh);
|
||||
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);
|
||||
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
|
||||
int idx = pid[0];
|
||||
|
||||
bst_float const * point = &preds_acc[idx * nclass];
|
||||
|
||||
// Part of Softmax function
|
||||
bst_float wmax = std::numeric_limits<bst_float>::min();
|
||||
for (int k = 0; k < nclass; k++) { wmax = ::sycl::max(point[k], wmax); }
|
||||
float wsum = 0.0f;
|
||||
for (int k = 0; k < nclass; k++) { wsum += ::sycl::exp(point[k] - wmax); }
|
||||
auto label = labels_acc[idx];
|
||||
if (label < 0 || label >= nclass) {
|
||||
flag_buf_acc[0] = 0;
|
||||
label = 0;
|
||||
}
|
||||
bst_float wt = is_null_weight ? 1.0f : weights_acc[idx];
|
||||
for (int k = 0; k < nclass; ++k) {
|
||||
bst_float p = expf(point[k] - wmax) / static_cast<float>(wsum);
|
||||
const float eps = 1e-16f;
|
||||
const bst_float h = ::sycl::max(2.0f * p * (1.0f - p) * wt, eps);
|
||||
p = label == k ? p - 1.0f : p;
|
||||
out_gpair_acc[idx * nclass + k] = GradientPair(p * wt, h);
|
||||
}
|
||||
});
|
||||
}).wait();
|
||||
}
|
||||
// flag_buf is destroyed, content is copyed to the "flag"
|
||||
|
||||
if (flag == 0) {
|
||||
LOG(FATAL) << "SYCL::SoftmaxMultiClassObj: label must be in [0, num_class).";
|
||||
}
|
||||
}
|
||||
void PredTransform(HostDeviceVector<bst_float>* io_preds) const override {
|
||||
this->Transform(io_preds, output_prob_);
|
||||
}
|
||||
void EvalTransform(HostDeviceVector<bst_float>* io_preds) override {
|
||||
this->Transform(io_preds, true);
|
||||
}
|
||||
const char* DefaultEvalMetric() const override {
|
||||
return "mlogloss";
|
||||
}
|
||||
|
||||
inline void Transform(HostDeviceVector<bst_float> *io_preds, bool prob) const {
|
||||
if (io_preds->Size() == 0) return;
|
||||
const int nclass = param_.num_class;
|
||||
const auto ndata = static_cast<int64_t>(io_preds->Size() / nclass);
|
||||
max_preds_.Resize(ndata);
|
||||
|
||||
{
|
||||
::sycl::buffer<bst_float, 1> io_preds_buf(io_preds->HostPointer(), io_preds->Size());
|
||||
|
||||
if (prob) {
|
||||
qu_.submit([&](::sycl::handler& cgh) {
|
||||
auto io_preds_acc = io_preds_buf.get_access<::sycl::access::mode::read_write>(cgh);
|
||||
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
|
||||
int idx = pid[0];
|
||||
auto it = io_preds_acc.begin() + idx * nclass;
|
||||
common::Softmax(it, it + nclass);
|
||||
});
|
||||
}).wait();
|
||||
} else {
|
||||
::sycl::buffer<bst_float, 1> max_preds_buf(max_preds_.HostPointer(), max_preds_.Size());
|
||||
|
||||
qu_.submit([&](::sycl::handler& cgh) {
|
||||
auto io_preds_acc = io_preds_buf.get_access<::sycl::access::mode::read>(cgh);
|
||||
auto max_preds_acc = max_preds_buf.get_access<::sycl::access::mode::read_write>(cgh);
|
||||
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
|
||||
int idx = pid[0];
|
||||
auto it = io_preds_acc.begin() + idx * nclass;
|
||||
max_preds_acc[idx] = common::FindMaxIndex(it, it + nclass) - it;
|
||||
});
|
||||
}).wait();
|
||||
}
|
||||
}
|
||||
|
||||
if (!prob) {
|
||||
io_preds->Resize(max_preds_.Size());
|
||||
io_preds->Copy(max_preds_);
|
||||
}
|
||||
}
|
||||
|
||||
struct ObjInfo Task() const override {return {ObjInfo::kClassification}; }
|
||||
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
if (this->output_prob_) {
|
||||
out["name"] = String("multi:softprob");
|
||||
} else {
|
||||
out["name"] = String("multi:softmax");
|
||||
}
|
||||
out["softmax_multiclass_param"] = ToJson(param_);
|
||||
}
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
FromJson(in["softmax_multiclass_param"], ¶m_);
|
||||
}
|
||||
|
||||
private:
|
||||
// output probability
|
||||
bool output_prob_;
|
||||
// parameter
|
||||
xgboost::obj::SoftmaxMultiClassParam param_;
|
||||
// Cache for max_preds
|
||||
mutable HostDeviceVector<bst_float> max_preds_;
|
||||
|
||||
sycl::DeviceManager device_manager;
|
||||
|
||||
mutable ::sycl::queue qu_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax_sycl")
|
||||
.describe("Softmax for multi-class classification, output class index.")
|
||||
.set_body([]() { return new SoftmaxMultiClassObj(false); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(SoftprobMultiClass, "multi:softprob_sycl")
|
||||
.describe("Softmax for multi-class classification, output probability distribution.")
|
||||
.set_body([]() { return new SoftmaxMultiClassObj(true); });
|
||||
|
||||
} // namespace obj
|
||||
} // namespace sycl
|
||||
} // namespace xgboost
|
||||
197
plugin/sycl/objective/regression_obj.cc
Normal file
197
plugin/sycl/objective/regression_obj.cc
Normal file
@@ -0,0 +1,197 @@
|
||||
/*!
|
||||
* Copyright 2015-2023 by Contributors
|
||||
* \file regression_obj.cc
|
||||
* \brief Definition of regression objectives.
|
||||
*/
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||
#pragma GCC diagnostic ignored "-W#pragma-messages"
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/objective.h>
|
||||
#pragma GCC diagnostic pop
|
||||
#include <rabit/rabit.h>
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "xgboost/host_device_vector.h"
|
||||
#include "xgboost/json.h"
|
||||
#include "xgboost/parameter.h"
|
||||
#include "xgboost/span.h"
|
||||
|
||||
#include "../../src/common/transform.h"
|
||||
#include "../../src/common/common.h"
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||
#include "../../../src/objective/regression_loss.h"
|
||||
#pragma GCC diagnostic pop
|
||||
#include "../../../src/objective/regression_param.h"
|
||||
|
||||
#include "../device_manager.h"
|
||||
|
||||
#include <CL/sycl.hpp>
|
||||
|
||||
namespace xgboost {
|
||||
namespace sycl {
|
||||
namespace obj {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(regression_obj_sycl);
|
||||
|
||||
template<typename Loss>
|
||||
class RegLossObj : public ObjFunction {
|
||||
protected:
|
||||
HostDeviceVector<int> label_correct_;
|
||||
|
||||
public:
|
||||
RegLossObj() = default;
|
||||
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.UpdateAllowUnknown(args);
|
||||
qu_ = device_manager.GetQueue(ctx_->Device());
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
linalg::Matrix<GradientPair>* out_gpair) override {
|
||||
if (info.labels.Size() == 0) return;
|
||||
CHECK_EQ(preds.Size(), info.labels.Size())
|
||||
<< " " << "labels are not correctly provided"
|
||||
<< "preds.size=" << preds.Size() << ", label.size=" << info.labels.Size() << ", "
|
||||
<< "Loss: " << Loss::Name();
|
||||
|
||||
size_t const ndata = preds.Size();
|
||||
auto const n_targets = this->Targets(info);
|
||||
out_gpair->Reshape(info.num_row_, n_targets);
|
||||
|
||||
// TODO(razdoburdin): add label_correct check
|
||||
label_correct_.Resize(1);
|
||||
label_correct_.Fill(1);
|
||||
|
||||
bool is_null_weight = info.weights_.Size() == 0;
|
||||
|
||||
::sycl::buffer<bst_float, 1> preds_buf(preds.HostPointer(), preds.Size());
|
||||
::sycl::buffer<bst_float, 1> labels_buf(info.labels.Data()->HostPointer(), info.labels.Size());
|
||||
::sycl::buffer<GradientPair, 1> out_gpair_buf(out_gpair->Data()->HostPointer(),
|
||||
out_gpair->Size());
|
||||
::sycl::buffer<bst_float, 1> weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(),
|
||||
is_null_weight ? 1 : info.weights_.Size());
|
||||
|
||||
auto scale_pos_weight = param_.scale_pos_weight;
|
||||
if (!is_null_weight) {
|
||||
CHECK_EQ(info.weights_.Size(), info.labels.Shape(0))
|
||||
<< "Number of weights should be equal to number of data points.";
|
||||
}
|
||||
|
||||
int flag = 1;
|
||||
{
|
||||
::sycl::buffer<int, 1> flag_buf(&flag, 1);
|
||||
qu_.submit([&](::sycl::handler& cgh) {
|
||||
auto preds_acc = preds_buf.get_access<::sycl::access::mode::read>(cgh);
|
||||
auto labels_acc = labels_buf.get_access<::sycl::access::mode::read>(cgh);
|
||||
auto weights_acc = weights_buf.get_access<::sycl::access::mode::read>(cgh);
|
||||
auto out_gpair_acc = out_gpair_buf.get_access<::sycl::access::mode::write>(cgh);
|
||||
auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh);
|
||||
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
|
||||
int idx = pid[0];
|
||||
bst_float p = Loss::PredTransform(preds_acc[idx]);
|
||||
bst_float w = is_null_weight ? 1.0f : weights_acc[idx/n_targets];
|
||||
bst_float label = labels_acc[idx];
|
||||
if (label == 1.0f) {
|
||||
w *= scale_pos_weight;
|
||||
}
|
||||
if (!Loss::CheckLabel(label)) {
|
||||
// If there is an incorrect label, the host code will know.
|
||||
flag_buf_acc[0] = 0;
|
||||
}
|
||||
out_gpair_acc[idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w,
|
||||
Loss::SecondOrderGradient(p, label) * w);
|
||||
});
|
||||
}).wait();
|
||||
}
|
||||
// flag_buf is destroyed, content is copyed to the "flag"
|
||||
|
||||
if (flag == 0) {
|
||||
LOG(FATAL) << Loss::LabelErrorMsg();
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
const char* DefaultEvalMetric() const override {
|
||||
return Loss::DefaultEvalMetric();
|
||||
}
|
||||
|
||||
void PredTransform(HostDeviceVector<float> *io_preds) const override {
|
||||
size_t const ndata = io_preds->Size();
|
||||
if (ndata == 0) return;
|
||||
::sycl::buffer<bst_float, 1> io_preds_buf(io_preds->HostPointer(), io_preds->Size());
|
||||
|
||||
qu_.submit([&](::sycl::handler& cgh) {
|
||||
auto io_preds_acc = io_preds_buf.get_access<::sycl::access::mode::read_write>(cgh);
|
||||
cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) {
|
||||
int idx = pid[0];
|
||||
io_preds_acc[idx] = Loss::PredTransform(io_preds_acc[idx]);
|
||||
});
|
||||
}).wait();
|
||||
}
|
||||
|
||||
float ProbToMargin(float base_score) const override {
|
||||
return Loss::ProbToMargin(base_score);
|
||||
}
|
||||
|
||||
struct ObjInfo Task() const override {
|
||||
return Loss::Info();
|
||||
};
|
||||
|
||||
uint32_t Targets(MetaInfo const& info) const override {
|
||||
// Multi-target regression.
|
||||
return std::max(static_cast<size_t>(1), info.labels.Shape(1));
|
||||
}
|
||||
|
||||
void SaveConfig(Json* p_out) const override {
|
||||
auto& out = *p_out;
|
||||
out["name"] = String(Loss::Name());
|
||||
out["reg_loss_param"] = ToJson(param_);
|
||||
}
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
FromJson(in["reg_loss_param"], ¶m_);
|
||||
}
|
||||
|
||||
protected:
|
||||
xgboost::obj::RegLossParam param_;
|
||||
sycl::DeviceManager device_manager;
|
||||
|
||||
mutable ::sycl::queue qu_;
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegression,
|
||||
std::string(xgboost::obj::LinearSquareLoss::Name()) + "_sycl")
|
||||
.describe("Regression with squared error with SYCL backend.")
|
||||
.set_body([]() { return new RegLossObj<xgboost::obj::LinearSquareLoss>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(SquareLogError,
|
||||
std::string(xgboost::obj::SquaredLogError::Name()) + "_sycl")
|
||||
.describe("Regression with root mean squared logarithmic error with SYCL backend.")
|
||||
.set_body([]() { return new RegLossObj<xgboost::obj::SquaredLogError>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(LogisticRegression,
|
||||
std::string(xgboost::obj::LogisticRegression::Name()) + "_sycl")
|
||||
.describe("Logistic regression for probability regression task with SYCL backend.")
|
||||
.set_body([]() { return new RegLossObj<xgboost::obj::LogisticRegression>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(LogisticClassification,
|
||||
std::string(xgboost::obj::LogisticClassification::Name()) + "_sycl")
|
||||
.describe("Logistic regression for binary classification task with SYCL backend.")
|
||||
.set_body([]() { return new RegLossObj<xgboost::obj::LogisticClassification>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(LogisticRaw,
|
||||
std::string(xgboost::obj::LogisticRaw::Name()) + "_sycl")
|
||||
.describe("Logistic regression for classification, output score "
|
||||
"before logistic transformation with SYCL backend.")
|
||||
.set_body([]() { return new RegLossObj<xgboost::obj::LogisticRaw>(); });
|
||||
|
||||
} // namespace obj
|
||||
} // namespace sycl
|
||||
} // namespace xgboost
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "xgboost/tree_model.h"
|
||||
#include "xgboost/predictor.h"
|
||||
#include "xgboost/tree_updater.h"
|
||||
#include "../../../src/common/timer.h"
|
||||
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||
@@ -36,36 +37,37 @@ namespace predictor {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(predictor_sycl);
|
||||
|
||||
/* Wrapper for descriptor of a tree node */
|
||||
struct DeviceNode {
|
||||
DeviceNode()
|
||||
: fidx(-1), left_child_idx(-1), right_child_idx(-1) {}
|
||||
|
||||
union NodeValue {
|
||||
float leaf_weight;
|
||||
float fvalue;
|
||||
};
|
||||
union NodeValue {
|
||||
float leaf_weight;
|
||||
float fvalue;
|
||||
};
|
||||
|
||||
class Node {
|
||||
int fidx;
|
||||
int left_child_idx;
|
||||
int right_child_idx;
|
||||
NodeValue val;
|
||||
|
||||
explicit DeviceNode(const RegTree::Node& n) {
|
||||
this->left_child_idx = n.LeftChild();
|
||||
this->right_child_idx = n.RightChild();
|
||||
this->fidx = n.SplitIndex();
|
||||
public:
|
||||
explicit Node(const RegTree::Node& n) {
|
||||
left_child_idx = n.LeftChild();
|
||||
right_child_idx = n.RightChild();
|
||||
fidx = n.SplitIndex();
|
||||
if (n.DefaultLeft()) {
|
||||
fidx |= (1U << 31);
|
||||
}
|
||||
|
||||
if (n.IsLeaf()) {
|
||||
this->val.leaf_weight = n.LeafValue();
|
||||
val.leaf_weight = n.LeafValue();
|
||||
} else {
|
||||
this->val.fvalue = n.SplitCond();
|
||||
val.fvalue = n.SplitCond();
|
||||
}
|
||||
}
|
||||
|
||||
int LeftChildIdx() const {return left_child_idx; }
|
||||
|
||||
int RightChildIdx() const {return right_child_idx; }
|
||||
|
||||
bool IsLeaf() const { return left_child_idx == -1; }
|
||||
|
||||
int GetFidx() const { return fidx & ((1U << 31) - 1U); }
|
||||
@@ -74,9 +76,9 @@ struct DeviceNode {
|
||||
|
||||
int MissingIdx() const {
|
||||
if (MissingLeft()) {
|
||||
return this->left_child_idx;
|
||||
return left_child_idx;
|
||||
} else {
|
||||
return this->right_child_idx;
|
||||
return right_child_idx;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,105 +87,79 @@ struct DeviceNode {
|
||||
float GetWeight() const { return val.leaf_weight; }
|
||||
};
|
||||
|
||||
/* SYCL implementation of a device model,
|
||||
* storing tree structure in USM buffers to provide access from device kernels
|
||||
*/
|
||||
class DeviceModel {
|
||||
public:
|
||||
::sycl::queue qu_;
|
||||
USMVector<DeviceNode> nodes_;
|
||||
USMVector<size_t> tree_segments_;
|
||||
USMVector<int> tree_group_;
|
||||
size_t tree_beg_;
|
||||
size_t tree_end_;
|
||||
int num_group_;
|
||||
USMVector<Node> nodes;
|
||||
USMVector<size_t> first_node_position;
|
||||
USMVector<int> tree_group;
|
||||
size_t tree_beg;
|
||||
size_t tree_end;
|
||||
int num_group;
|
||||
|
||||
DeviceModel() {}
|
||||
|
||||
~DeviceModel() {}
|
||||
|
||||
void Init(::sycl::queue qu, const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) {
|
||||
qu_ = qu;
|
||||
|
||||
tree_segments_.Resize(&qu_, (tree_end - tree_begin) + 1);
|
||||
int sum = 0;
|
||||
tree_segments_[0] = sum;
|
||||
void Init(::sycl::queue* qu, const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) {
|
||||
int n_nodes = 0;
|
||||
first_node_position.Resize(qu, (tree_end - tree_begin) + 1);
|
||||
first_node_position[0] = n_nodes;
|
||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
if (model.trees[tree_idx]->HasCategoricalSplit()) {
|
||||
LOG(FATAL) << "Categorical features are not yet supported by sycl";
|
||||
}
|
||||
sum += model.trees[tree_idx]->GetNodes().size();
|
||||
tree_segments_[tree_idx - tree_begin + 1] = sum;
|
||||
n_nodes += model.trees[tree_idx]->GetNodes().size();
|
||||
first_node_position[tree_idx - tree_begin + 1] = n_nodes;
|
||||
}
|
||||
|
||||
nodes_.Resize(&qu_, sum);
|
||||
nodes.Resize(qu, n_nodes);
|
||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
auto& src_nodes = model.trees[tree_idx]->GetNodes();
|
||||
for (size_t node_idx = 0; node_idx < src_nodes.size(); node_idx++)
|
||||
nodes_[node_idx + tree_segments_[tree_idx - tree_begin]] =
|
||||
static_cast<DeviceNode>(src_nodes[node_idx]);
|
||||
size_t n_nodes_shift = first_node_position[tree_idx - tree_begin];
|
||||
for (size_t node_idx = 0; node_idx < src_nodes.size(); node_idx++) {
|
||||
nodes[node_idx + n_nodes_shift] = static_cast<Node>(src_nodes[node_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
tree_group_.Resize(&qu_, model.tree_info.size());
|
||||
tree_group.Resize(qu, model.tree_info.size());
|
||||
for (size_t tree_idx = 0; tree_idx < model.tree_info.size(); tree_idx++)
|
||||
tree_group_[tree_idx] = model.tree_info[tree_idx];
|
||||
tree_group[tree_idx] = model.tree_info[tree_idx];
|
||||
|
||||
tree_beg_ = tree_begin;
|
||||
tree_end_ = tree_end;
|
||||
num_group_ = model.learner_model_param->num_output_group;
|
||||
tree_beg = tree_begin;
|
||||
tree_end = tree_end;
|
||||
num_group = model.learner_model_param->num_output_group;
|
||||
}
|
||||
};
|
||||
|
||||
float GetFvalue(int ridx, int fidx, Entry* data, size_t* row_ptr, bool* is_missing) {
|
||||
// Binary search
|
||||
auto begin_ptr = data + row_ptr[ridx];
|
||||
auto end_ptr = data + row_ptr[ridx + 1];
|
||||
Entry* previous_middle = nullptr;
|
||||
while (end_ptr != begin_ptr) {
|
||||
auto middle = begin_ptr + (end_ptr - begin_ptr) / 2;
|
||||
if (middle == previous_middle) {
|
||||
break;
|
||||
float GetLeafWeight(const Node* nodes, const float* fval_buff, const uint8_t* miss_buff) {
|
||||
const Node* node = nodes;
|
||||
while (!node->IsLeaf()) {
|
||||
if (miss_buff[node->GetFidx()] == 1) {
|
||||
node = nodes + node->MissingIdx();
|
||||
} else {
|
||||
previous_middle = middle;
|
||||
}
|
||||
|
||||
if (middle->index == fidx) {
|
||||
*is_missing = false;
|
||||
return middle->fvalue;
|
||||
} else if (middle->index < fidx) {
|
||||
begin_ptr = middle;
|
||||
} else {
|
||||
end_ptr = middle;
|
||||
}
|
||||
}
|
||||
*is_missing = true;
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
float GetLeafWeight(int ridx, const DeviceNode* tree, Entry* data, size_t* row_ptr) {
|
||||
DeviceNode n = tree[0];
|
||||
int node_id = 0;
|
||||
bool is_missing;
|
||||
while (!n.IsLeaf()) {
|
||||
float fvalue = GetFvalue(ridx, n.GetFidx(), data, row_ptr, &is_missing);
|
||||
// Missing value
|
||||
if (is_missing) {
|
||||
n = tree[n.MissingIdx()];
|
||||
} else {
|
||||
if (fvalue < n.GetFvalue()) {
|
||||
node_id = n.left_child_idx;
|
||||
n = tree[n.left_child_idx];
|
||||
const float fvalue = fval_buff[node->GetFidx()];
|
||||
if (fvalue < node->GetFvalue()) {
|
||||
node = nodes + node->LeftChildIdx();
|
||||
} else {
|
||||
node_id = n.right_child_idx;
|
||||
n = tree[n.right_child_idx];
|
||||
node = nodes + node->RightChildIdx();
|
||||
}
|
||||
}
|
||||
}
|
||||
return n.GetWeight();
|
||||
return node->GetWeight();
|
||||
}
|
||||
|
||||
void DevicePredictInternal(::sycl::queue qu,
|
||||
sycl::DeviceMatrix* dmat,
|
||||
float GetLeafWeight(const Node* nodes, const float* fval_buff) {
|
||||
const Node* node = nodes;
|
||||
while (!node->IsLeaf()) {
|
||||
const float fvalue = fval_buff[node->GetFidx()];
|
||||
if (fvalue < node->GetFvalue()) {
|
||||
node = nodes + node->LeftChildIdx();
|
||||
} else {
|
||||
node = nodes + node->RightChildIdx();
|
||||
}
|
||||
}
|
||||
return node->GetWeight();
|
||||
}
|
||||
|
||||
template <bool any_missing>
|
||||
void DevicePredictInternal(::sycl::queue* qu,
|
||||
const sycl::DeviceMatrix& dmat,
|
||||
HostDeviceVector<float>* out_preds,
|
||||
const gbm::GBTreeModel& model,
|
||||
size_t tree_begin,
|
||||
@@ -194,43 +170,75 @@ void DevicePredictInternal(::sycl::queue qu,
|
||||
DeviceModel device_model;
|
||||
device_model.Init(qu, model, tree_begin, tree_end);
|
||||
|
||||
auto& out_preds_vec = out_preds->HostVector();
|
||||
|
||||
DeviceNode* nodes = device_model.nodes_.Data();
|
||||
::sycl::buffer<float, 1> out_preds_buf(out_preds_vec.data(), out_preds_vec.size());
|
||||
size_t* tree_segments = device_model.tree_segments_.Data();
|
||||
int* tree_group = device_model.tree_group_.Data();
|
||||
size_t* row_ptr = dmat->row_ptr.Data();
|
||||
Entry* data = dmat->data.Data();
|
||||
int num_features = dmat->p_mat->Info().num_col_;
|
||||
int num_rows = dmat->row_ptr.Size() - 1;
|
||||
const Node* nodes = device_model.nodes.DataConst();
|
||||
const size_t* first_node_position = device_model.first_node_position.DataConst();
|
||||
const int* tree_group = device_model.tree_group.DataConst();
|
||||
const size_t* row_ptr = dmat.row_ptr.DataConst();
|
||||
const Entry* data = dmat.data.DataConst();
|
||||
int num_features = dmat.p_mat->Info().num_col_;
|
||||
int num_rows = dmat.row_ptr.Size() - 1;
|
||||
int num_group = model.learner_model_param->num_output_group;
|
||||
|
||||
qu.submit([&](::sycl::handler& cgh) {
|
||||
USMVector<float, MemoryType::on_device> fval_buff(qu, num_features * num_rows);
|
||||
USMVector<uint8_t, MemoryType::on_device> miss_buff;
|
||||
auto* fval_buff_ptr = fval_buff.Data();
|
||||
|
||||
std::vector<::sycl::event> events(1);
|
||||
if constexpr (any_missing) {
|
||||
miss_buff.Resize(qu, num_features * num_rows, 1, &events[0]);
|
||||
}
|
||||
auto* miss_buff_ptr = miss_buff.Data();
|
||||
|
||||
auto& out_preds_vec = out_preds->HostVector();
|
||||
::sycl::buffer<float, 1> out_preds_buf(out_preds_vec.data(), out_preds_vec.size());
|
||||
events[0] = qu->submit([&](::sycl::handler& cgh) {
|
||||
cgh.depends_on(events[0]);
|
||||
auto out_predictions = out_preds_buf.template get_access<::sycl::access::mode::read_write>(cgh);
|
||||
cgh.parallel_for<>(::sycl::range<1>(num_rows), [=](::sycl::id<1> pid) {
|
||||
int global_idx = pid[0];
|
||||
if (global_idx >= num_rows) return;
|
||||
int row_idx = pid[0];
|
||||
auto* fval_buff_row_ptr = fval_buff_ptr + num_features * row_idx;
|
||||
auto* miss_buff_row_ptr = miss_buff_ptr + num_features * row_idx;
|
||||
|
||||
const Entry* first_entry = data + row_ptr[row_idx];
|
||||
const Entry* last_entry = data + row_ptr[row_idx + 1];
|
||||
for (const Entry* entry = first_entry; entry < last_entry; entry += 1) {
|
||||
fval_buff_row_ptr[entry->index] = entry->fvalue;
|
||||
if constexpr (any_missing) {
|
||||
miss_buff_row_ptr[entry->index] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (num_group == 1) {
|
||||
float sum = 0.0;
|
||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
const DeviceNode* tree = nodes + tree_segments[tree_idx - tree_begin];
|
||||
sum += GetLeafWeight(global_idx, tree, data, row_ptr);
|
||||
const Node* first_node = nodes + first_node_position[tree_idx - tree_begin];
|
||||
if constexpr (any_missing) {
|
||||
sum += GetLeafWeight(first_node, fval_buff_row_ptr, miss_buff_row_ptr);
|
||||
} else {
|
||||
sum += GetLeafWeight(first_node, fval_buff_row_ptr);
|
||||
}
|
||||
}
|
||||
out_predictions[global_idx] += sum;
|
||||
out_predictions[row_idx] += sum;
|
||||
} else {
|
||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||
const DeviceNode* tree = nodes + tree_segments[tree_idx - tree_begin];
|
||||
int out_prediction_idx = global_idx * num_group + tree_group[tree_idx];
|
||||
out_predictions[out_prediction_idx] += GetLeafWeight(global_idx, tree, data, row_ptr);
|
||||
const Node* first_node = nodes + first_node_position[tree_idx - tree_begin];
|
||||
int out_prediction_idx = row_idx * num_group + tree_group[tree_idx];
|
||||
if constexpr (any_missing) {
|
||||
out_predictions[out_prediction_idx] +=
|
||||
GetLeafWeight(first_node, fval_buff_row_ptr, miss_buff_row_ptr);
|
||||
} else {
|
||||
out_predictions[out_prediction_idx] +=
|
||||
GetLeafWeight(first_node, fval_buff_row_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}).wait();
|
||||
});
|
||||
qu->wait();
|
||||
}
|
||||
|
||||
class Predictor : public xgboost::Predictor {
|
||||
protected:
|
||||
public:
|
||||
void InitOutPredictions(const MetaInfo& info,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
const gbm::GBTreeModel& model) const override {
|
||||
@@ -263,7 +271,6 @@ class Predictor : public xgboost::Predictor {
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
explicit Predictor(Context const* context) :
|
||||
xgboost::Predictor::Predictor{context},
|
||||
cpu_predictor(xgboost::Predictor::Create("cpu_predictor", context)) {}
|
||||
@@ -281,7 +288,12 @@ class Predictor : public xgboost::Predictor {
|
||||
}
|
||||
|
||||
if (tree_begin < tree_end) {
|
||||
DevicePredictInternal(qu, &device_matrix, out_preds, model, tree_begin, tree_end);
|
||||
const bool any_missing = !(dmat->IsDense());
|
||||
if (any_missing) {
|
||||
DevicePredictInternal<true>(&qu, device_matrix, out_preds, model, tree_begin, tree_end);
|
||||
} else {
|
||||
DevicePredictInternal<false>(&qu, device_matrix, out_preds, model, tree_begin, tree_end);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user