From 43897b829680d241491abe1ecd46b2ba9d338967 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Tue, 12 Dec 2023 07:41:50 +0100 Subject: [PATCH] Sycl implementation for objective functions (#9846) --------- Co-authored-by: Dmitry Razdoburdin <> --- include/xgboost/objective.h | 6 + plugin/CMakeLists.txt | 2 + plugin/sycl/objective/multiclass_obj.cc | 210 +++++++++ plugin/sycl/objective/regression_obj.cc | 197 +++++++++ src/objective/multiclass_obj.cu | 11 +- src/objective/multiclass_param.h | 25 ++ src/objective/objective.cc | 22 +- src/objective/regression_obj.cu | 11 +- src/objective/regression_param.h | 25 ++ tests/cpp/objective/test_multiclass_obj.cc | 19 +- tests/cpp/objective/test_multiclass_obj.h | 19 + .../cpp/objective/test_multiclass_obj_cpu.cc | 25 ++ .../cpp/objective/test_multiclass_obj_gpu.cu | 2 +- tests/cpp/objective/test_regression_obj.cc | 414 +----------------- tests/cpp/objective/test_regression_obj.h | 23 + .../cpp/objective/test_regression_obj_cpu.cc | 412 +++++++++++++++++ .../cpp/objective/test_regression_obj_gpu.cu | 2 +- tests/cpp/plugin/test_sycl_multiclass_obj.cc | 28 ++ tests/cpp/plugin/test_sycl_regression_obj.cc | 99 +++++ 19 files changed, 1129 insertions(+), 423 deletions(-) create mode 100644 plugin/sycl/objective/multiclass_obj.cc create mode 100644 plugin/sycl/objective/regression_obj.cc create mode 100644 src/objective/multiclass_param.h create mode 100644 src/objective/regression_param.h create mode 100644 tests/cpp/objective/test_multiclass_obj.h create mode 100644 tests/cpp/objective/test_multiclass_obj_cpu.cc create mode 100644 tests/cpp/objective/test_regression_obj.h create mode 100644 tests/cpp/objective/test_regression_obj_cpu.cc create mode 100644 tests/cpp/plugin/test_sycl_multiclass_obj.cc create mode 100644 tests/cpp/plugin/test_sycl_regression_obj.cc diff --git a/include/xgboost/objective.h b/include/xgboost/objective.h index d2623ee01..b88c30552 100644 --- a/include/xgboost/objective.h +++ b/include/xgboost/objective.h @@ -129,6 +129,12 @@ class ObjFunction : public Configurable { * \param name Name of the objective. */ static ObjFunction* Create(const std::string& name, Context const* ctx); + + /*! + * \brief Return sycl specific implementation name if possible. + * \param name Name of the objective. + */ + static std::string GetSyclImplementationName(const std::string& name); }; /*! diff --git a/plugin/CMakeLists.txt b/plugin/CMakeLists.txt index 0fecb4fb2..e575f1a41 100644 --- a/plugin/CMakeLists.txt +++ b/plugin/CMakeLists.txt @@ -1,6 +1,8 @@ if(PLUGIN_SYCL) set(CMAKE_CXX_COMPILER "icpx") add_library(plugin_sycl OBJECT + ${xgboost_SOURCE_DIR}/plugin/sycl/objective/regression_obj.cc + ${xgboost_SOURCE_DIR}/plugin/sycl/objective/multiclass_obj.cc ${xgboost_SOURCE_DIR}/plugin/sycl/device_manager.cc ${xgboost_SOURCE_DIR}/plugin/sycl/predictor/predictor.cc) target_include_directories(plugin_sycl diff --git a/plugin/sycl/objective/multiclass_obj.cc b/plugin/sycl/objective/multiclass_obj.cc new file mode 100644 index 000000000..3104dd35e --- /dev/null +++ b/plugin/sycl/objective/multiclass_obj.cc @@ -0,0 +1,210 @@ +/*! + * Copyright 2015-2023 by Contributors + * \file multiclass_obj.cc + * \brief Definition of multi-class classification objectives. + */ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#pragma GCC diagnostic ignored "-W#pragma-messages" +#include +#pragma GCC diagnostic pop + +#include +#include +#include +#include + +#include "xgboost/parameter.h" +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#include "xgboost/data.h" +#include "../../src/common/math.h" +#pragma GCC diagnostic pop +#include "xgboost/logging.h" +#include "xgboost/objective.h" +#include "xgboost/json.h" +#include "xgboost/span.h" + +#include "../../../src/objective/multiclass_param.h" + +#include "../device_manager.h" +#include + +namespace xgboost { +namespace sycl { +namespace obj { + +DMLC_REGISTRY_FILE_TAG(multiclass_obj_sycl); + +class SoftmaxMultiClassObj : public ObjFunction { + public: + explicit SoftmaxMultiClassObj(bool output_prob) + : output_prob_(output_prob) {} + + void Configure(Args const& args) override { + param_.UpdateAllowUnknown(args); + qu_ = device_manager.GetQueue(ctx_->Device()); + } + + void GetGradient(const HostDeviceVector& preds, + const MetaInfo& info, + int iter, + linalg::Matrix* out_gpair) override { + if (preds.Size() == 0) return; + if (info.labels.Size() == 0) return; + + CHECK(preds.Size() == (static_cast(param_.num_class) * info.labels.Size())) + << "SoftmaxMultiClassObj: label size and pred size does not match.\n" + << "label.Size() * num_class: " + << info.labels.Size() * static_cast(param_.num_class) << "\n" + << "num_class: " << param_.num_class << "\n" + << "preds.Size(): " << preds.Size(); + + const int nclass = param_.num_class; + const auto ndata = static_cast(preds.Size() / nclass); + + out_gpair->Reshape(info.num_row_, static_cast(nclass)); + + const bool is_null_weight = info.weights_.Size() == 0; + if (!is_null_weight) { + CHECK_EQ(info.weights_.Size(), ndata) + << "Number of weights should be equal to number of data points."; + } + + ::sycl::buffer preds_buf(preds.HostPointer(), preds.Size()); + ::sycl::buffer labels_buf(info.labels.Data()->HostPointer(), info.labels.Size()); + ::sycl::buffer out_gpair_buf(out_gpair->Data()->HostPointer(), + out_gpair->Size()); + ::sycl::buffer weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(), + is_null_weight ? 1 : info.weights_.Size()); + + int flag = 1; + { + ::sycl::buffer flag_buf(&flag, 1); + qu_.submit([&](::sycl::handler& cgh) { + auto preds_acc = preds_buf.get_access<::sycl::access::mode::read>(cgh); + auto labels_acc = labels_buf.get_access<::sycl::access::mode::read>(cgh); + auto weights_acc = weights_buf.get_access<::sycl::access::mode::read>(cgh); + auto out_gpair_acc = out_gpair_buf.get_access<::sycl::access::mode::write>(cgh); + auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) { + int idx = pid[0]; + + bst_float const * point = &preds_acc[idx * nclass]; + + // Part of Softmax function + bst_float wmax = std::numeric_limits::min(); + for (int k = 0; k < nclass; k++) { wmax = ::sycl::max(point[k], wmax); } + float wsum = 0.0f; + for (int k = 0; k < nclass; k++) { wsum += ::sycl::exp(point[k] - wmax); } + auto label = labels_acc[idx]; + if (label < 0 || label >= nclass) { + flag_buf_acc[0] = 0; + label = 0; + } + bst_float wt = is_null_weight ? 1.0f : weights_acc[idx]; + for (int k = 0; k < nclass; ++k) { + bst_float p = expf(point[k] - wmax) / static_cast(wsum); + const float eps = 1e-16f; + const bst_float h = ::sycl::max(2.0f * p * (1.0f - p) * wt, eps); + p = label == k ? p - 1.0f : p; + out_gpair_acc[idx * nclass + k] = GradientPair(p * wt, h); + } + }); + }).wait(); + } + // flag_buf is destroyed, content is copyed to the "flag" + + if (flag == 0) { + LOG(FATAL) << "SYCL::SoftmaxMultiClassObj: label must be in [0, num_class)."; + } + } + void PredTransform(HostDeviceVector* io_preds) const override { + this->Transform(io_preds, output_prob_); + } + void EvalTransform(HostDeviceVector* io_preds) override { + this->Transform(io_preds, true); + } + const char* DefaultEvalMetric() const override { + return "mlogloss"; + } + + inline void Transform(HostDeviceVector *io_preds, bool prob) const { + if (io_preds->Size() == 0) return; + const int nclass = param_.num_class; + const auto ndata = static_cast(io_preds->Size() / nclass); + max_preds_.Resize(ndata); + + { + ::sycl::buffer io_preds_buf(io_preds->HostPointer(), io_preds->Size()); + + if (prob) { + qu_.submit([&](::sycl::handler& cgh) { + auto io_preds_acc = io_preds_buf.get_access<::sycl::access::mode::read_write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) { + int idx = pid[0]; + auto it = io_preds_acc.begin() + idx * nclass; + common::Softmax(it, it + nclass); + }); + }).wait(); + } else { + ::sycl::buffer max_preds_buf(max_preds_.HostPointer(), max_preds_.Size()); + + qu_.submit([&](::sycl::handler& cgh) { + auto io_preds_acc = io_preds_buf.get_access<::sycl::access::mode::read>(cgh); + auto max_preds_acc = max_preds_buf.get_access<::sycl::access::mode::read_write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) { + int idx = pid[0]; + auto it = io_preds_acc.begin() + idx * nclass; + max_preds_acc[idx] = common::FindMaxIndex(it, it + nclass) - it; + }); + }).wait(); + } + } + + if (!prob) { + io_preds->Resize(max_preds_.Size()); + io_preds->Copy(max_preds_); + } + } + + struct ObjInfo Task() const override {return {ObjInfo::kClassification}; } + + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + if (this->output_prob_) { + out["name"] = String("multi:softprob"); + } else { + out["name"] = String("multi:softmax"); + } + out["softmax_multiclass_param"] = ToJson(param_); + } + + void LoadConfig(Json const& in) override { + FromJson(in["softmax_multiclass_param"], ¶m_); + } + + private: + // output probability + bool output_prob_; + // parameter + xgboost::obj::SoftmaxMultiClassParam param_; + // Cache for max_preds + mutable HostDeviceVector max_preds_; + + sycl::DeviceManager device_manager; + + mutable ::sycl::queue qu_; +}; + +XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax_sycl") +.describe("Softmax for multi-class classification, output class index.") +.set_body([]() { return new SoftmaxMultiClassObj(false); }); + +XGBOOST_REGISTER_OBJECTIVE(SoftprobMultiClass, "multi:softprob_sycl") +.describe("Softmax for multi-class classification, output probability distribution.") +.set_body([]() { return new SoftmaxMultiClassObj(true); }); + +} // namespace obj +} // namespace sycl +} // namespace xgboost diff --git a/plugin/sycl/objective/regression_obj.cc b/plugin/sycl/objective/regression_obj.cc new file mode 100644 index 000000000..985498717 --- /dev/null +++ b/plugin/sycl/objective/regression_obj.cc @@ -0,0 +1,197 @@ +/*! + * Copyright 2015-2023 by Contributors + * \file regression_obj.cc + * \brief Definition of regression objectives. + */ +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#pragma GCC diagnostic ignored "-W#pragma-messages" +#include +#include +#pragma GCC diagnostic pop +#include + +#include +#include +#include + +#include "xgboost/host_device_vector.h" +#include "xgboost/json.h" +#include "xgboost/parameter.h" +#include "xgboost/span.h" + +#include "../../src/common/transform.h" +#include "../../src/common/common.h" +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wtautological-constant-compare" +#include "../../../src/objective/regression_loss.h" +#pragma GCC diagnostic pop +#include "../../../src/objective/regression_param.h" + +#include "../device_manager.h" + +#include + +namespace xgboost { +namespace sycl { +namespace obj { + +DMLC_REGISTRY_FILE_TAG(regression_obj_sycl); + +template +class RegLossObj : public ObjFunction { + protected: + HostDeviceVector label_correct_; + + public: + RegLossObj() = default; + + void Configure(const std::vector >& args) override { + param_.UpdateAllowUnknown(args); + qu_ = device_manager.GetQueue(ctx_->Device()); + } + + void GetGradient(const HostDeviceVector& preds, + const MetaInfo &info, + int iter, + linalg::Matrix* out_gpair) override { + if (info.labels.Size() == 0) return; + CHECK_EQ(preds.Size(), info.labels.Size()) + << " " << "labels are not correctly provided" + << "preds.size=" << preds.Size() << ", label.size=" << info.labels.Size() << ", " + << "Loss: " << Loss::Name(); + + size_t const ndata = preds.Size(); + auto const n_targets = this->Targets(info); + out_gpair->Reshape(info.num_row_, n_targets); + + // TODO(razdoburdin): add label_correct check + label_correct_.Resize(1); + label_correct_.Fill(1); + + bool is_null_weight = info.weights_.Size() == 0; + + ::sycl::buffer preds_buf(preds.HostPointer(), preds.Size()); + ::sycl::buffer labels_buf(info.labels.Data()->HostPointer(), info.labels.Size()); + ::sycl::buffer out_gpair_buf(out_gpair->Data()->HostPointer(), + out_gpair->Size()); + ::sycl::buffer weights_buf(is_null_weight ? NULL : info.weights_.HostPointer(), + is_null_weight ? 1 : info.weights_.Size()); + + auto scale_pos_weight = param_.scale_pos_weight; + if (!is_null_weight) { + CHECK_EQ(info.weights_.Size(), info.labels.Shape(0)) + << "Number of weights should be equal to number of data points."; + } + + int flag = 1; + { + ::sycl::buffer flag_buf(&flag, 1); + qu_.submit([&](::sycl::handler& cgh) { + auto preds_acc = preds_buf.get_access<::sycl::access::mode::read>(cgh); + auto labels_acc = labels_buf.get_access<::sycl::access::mode::read>(cgh); + auto weights_acc = weights_buf.get_access<::sycl::access::mode::read>(cgh); + auto out_gpair_acc = out_gpair_buf.get_access<::sycl::access::mode::write>(cgh); + auto flag_buf_acc = flag_buf.get_access<::sycl::access::mode::write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) { + int idx = pid[0]; + bst_float p = Loss::PredTransform(preds_acc[idx]); + bst_float w = is_null_weight ? 1.0f : weights_acc[idx/n_targets]; + bst_float label = labels_acc[idx]; + if (label == 1.0f) { + w *= scale_pos_weight; + } + if (!Loss::CheckLabel(label)) { + // If there is an incorrect label, the host code will know. + flag_buf_acc[0] = 0; + } + out_gpair_acc[idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w, + Loss::SecondOrderGradient(p, label) * w); + }); + }).wait(); + } + // flag_buf is destroyed, content is copyed to the "flag" + + if (flag == 0) { + LOG(FATAL) << Loss::LabelErrorMsg(); + } + } + + public: + const char* DefaultEvalMetric() const override { + return Loss::DefaultEvalMetric(); + } + + void PredTransform(HostDeviceVector *io_preds) const override { + size_t const ndata = io_preds->Size(); + if (ndata == 0) return; + ::sycl::buffer io_preds_buf(io_preds->HostPointer(), io_preds->Size()); + + qu_.submit([&](::sycl::handler& cgh) { + auto io_preds_acc = io_preds_buf.get_access<::sycl::access::mode::read_write>(cgh); + cgh.parallel_for<>(::sycl::range<1>(ndata), [=](::sycl::id<1> pid) { + int idx = pid[0]; + io_preds_acc[idx] = Loss::PredTransform(io_preds_acc[idx]); + }); + }).wait(); + } + + float ProbToMargin(float base_score) const override { + return Loss::ProbToMargin(base_score); + } + + struct ObjInfo Task() const override { + return Loss::Info(); + }; + + uint32_t Targets(MetaInfo const& info) const override { + // Multi-target regression. + return std::max(static_cast(1), info.labels.Shape(1)); + } + + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["name"] = String(Loss::Name()); + out["reg_loss_param"] = ToJson(param_); + } + + void LoadConfig(Json const& in) override { + FromJson(in["reg_loss_param"], ¶m_); + } + + protected: + xgboost::obj::RegLossParam param_; + sycl::DeviceManager device_manager; + + mutable ::sycl::queue qu_; +}; + +XGBOOST_REGISTER_OBJECTIVE(SquaredLossRegression, + std::string(xgboost::obj::LinearSquareLoss::Name()) + "_sycl") +.describe("Regression with squared error with SYCL backend.") +.set_body([]() { return new RegLossObj(); }); + +XGBOOST_REGISTER_OBJECTIVE(SquareLogError, + std::string(xgboost::obj::SquaredLogError::Name()) + "_sycl") +.describe("Regression with root mean squared logarithmic error with SYCL backend.") +.set_body([]() { return new RegLossObj(); }); + +XGBOOST_REGISTER_OBJECTIVE(LogisticRegression, + std::string(xgboost::obj::LogisticRegression::Name()) + "_sycl") +.describe("Logistic regression for probability regression task with SYCL backend.") +.set_body([]() { return new RegLossObj(); }); + +XGBOOST_REGISTER_OBJECTIVE(LogisticClassification, + std::string(xgboost::obj::LogisticClassification::Name()) + "_sycl") +.describe("Logistic regression for binary classification task with SYCL backend.") +.set_body([]() { return new RegLossObj(); }); + +XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, + std::string(xgboost::obj::LogisticRaw::Name()) + "_sycl") +.describe("Logistic regression for classification, output score " + "before logistic transformation with SYCL backend.") +.set_body([]() { return new RegLossObj(); }); + +} // namespace obj +} // namespace sycl +} // namespace xgboost diff --git a/src/objective/multiclass_obj.cu b/src/objective/multiclass_obj.cu index 38880f911..1a3df3884 100644 --- a/src/objective/multiclass_obj.cu +++ b/src/objective/multiclass_obj.cu @@ -21,6 +21,8 @@ #include "../common/math.h" #include "../common/transform.h" +#include "multiclass_param.h" + namespace xgboost { namespace obj { @@ -28,15 +30,6 @@ namespace obj { DMLC_REGISTRY_FILE_TAG(multiclass_obj_gpu); #endif // defined(XGBOOST_USE_CUDA) -struct SoftmaxMultiClassParam : public XGBoostParameter { - int num_class; - // declare parameters - DMLC_DECLARE_PARAMETER(SoftmaxMultiClassParam) { - DMLC_DECLARE_FIELD(num_class).set_lower_bound(1) - .describe("Number of output class in the multi-class classification."); - } -}; - class SoftmaxMultiClassObj : public ObjFunction { public: explicit SoftmaxMultiClassObj(bool output_prob) diff --git a/src/objective/multiclass_param.h b/src/objective/multiclass_param.h new file mode 100644 index 000000000..d1dea15fd --- /dev/null +++ b/src/objective/multiclass_param.h @@ -0,0 +1,25 @@ +/*! + * Copyright 2015-2023 by Contributors + * \file multiclass_param.h + * \brief Definition of multi-class classification parameters. + */ +#ifndef XGBOOST_OBJECTIVE_MULTICLASS_PARAM_H_ +#define XGBOOST_OBJECTIVE_MULTICLASS_PARAM_H_ + +#include "xgboost/parameter.h" + +namespace xgboost { +namespace obj { + +struct SoftmaxMultiClassParam : public XGBoostParameter { + int num_class; + // declare parameters + DMLC_DECLARE_PARAMETER(SoftmaxMultiClassParam) { + DMLC_DECLARE_FIELD(num_class).set_lower_bound(1) + .describe("Number of output class in the multi-class classification."); + } +}; + +} // namespace obj +} // namespace xgboost +#endif // XGBOOST_OBJECTIVE_MULTICLASS_PARAM_H_ diff --git a/src/objective/objective.cc b/src/objective/objective.cc index 85cd9803d..1ccf53264 100644 --- a/src/objective/objective.cc +++ b/src/objective/objective.cc @@ -18,7 +18,11 @@ DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg); namespace xgboost { // implement factory functions ObjFunction* ObjFunction::Create(const std::string& name, Context const* ctx) { - auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(name); + std::string obj_name = name; + if (ctx->IsSycl()) { + obj_name = GetSyclImplementationName(obj_name); + } + auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(obj_name); if (e == nullptr) { std::stringstream ss; for (const auto& entry : ::dmlc::Registry< ::xgboost::ObjFunctionReg>::List()) { @@ -32,6 +36,22 @@ ObjFunction* ObjFunction::Create(const std::string& name, Context const* ctx) { return pobj; } +/* If the objective function has sycl-specific implementation, + * returns the specific implementation name. + * Otherwise return the orginal name without modifications. + */ +std::string ObjFunction::GetSyclImplementationName(const std::string& name) { + const std::string sycl_postfix = "_sycl"; + auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(name + sycl_postfix); + if (e != nullptr) { + // Function has specific sycl implementation + return name + sycl_postfix; + } else { + // Function hasn't specific sycl implementation + return name; + } +} + void ObjFunction::InitEstimation(MetaInfo const&, linalg::Tensor* base_score) const { CHECK(base_score); base_score->Reshape(1); diff --git a/src/objective/regression_obj.cu b/src/objective/regression_obj.cu index 5627600fc..df30b354b 100644 --- a/src/objective/regression_obj.cu +++ b/src/objective/regression_obj.cu @@ -35,6 +35,8 @@ #include "xgboost/span.h" #include "xgboost/tree_model.h" // RegTree +#include "regression_param.h" + #if defined(XGBOOST_USE_CUDA) #include "../common/cuda_context.cuh" // for CUDAContext #include "../common/device_helpers.cuh" @@ -53,14 +55,7 @@ void CheckRegInputs(MetaInfo const& info, HostDeviceVector const& pre DMLC_REGISTRY_FILE_TAG(regression_obj_gpu); #endif // defined(XGBOOST_USE_CUDA) -struct RegLossParam : public XGBoostParameter { - float scale_pos_weight; - // declare parameters - DMLC_DECLARE_PARAMETER(RegLossParam) { - DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f) - .describe("Scale the weight of positive examples by this factor"); - } -}; + template class RegLossObj : public FitIntercept { diff --git a/src/objective/regression_param.h b/src/objective/regression_param.h new file mode 100644 index 000000000..8f5cd7112 --- /dev/null +++ b/src/objective/regression_param.h @@ -0,0 +1,25 @@ +/*! + * Copyright 2015-2023 by Contributors + * \file multiclass_param.h + * \brief Definition of single-value regression and classification parameters. + */ +#ifndef XGBOOST_OBJECTIVE_REGRESSION_PARAM_H_ +#define XGBOOST_OBJECTIVE_REGRESSION_PARAM_H_ + +#include "xgboost/parameter.h" + +namespace xgboost { +namespace obj { + +struct RegLossParam : public XGBoostParameter { + float scale_pos_weight; + // declare parameters + DMLC_DECLARE_PARAMETER(RegLossParam) { + DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f) + .describe("Scale the weight of positive examples by this factor"); + } +}; + +} // namespace obj +} // namespace xgboost +#endif // XGBOOST_OBJECTIVE_REGRESSION_PARAM_H_ diff --git a/tests/cpp/objective/test_multiclass_obj.cc b/tests/cpp/objective/test_multiclass_obj.cc index d028ef9cf..734e097b8 100644 --- a/tests/cpp/objective/test_multiclass_obj.cc +++ b/tests/cpp/objective/test_multiclass_obj.cc @@ -1,18 +1,18 @@ /*! - * Copyright 2018-2019 XGBoost contributors + * Copyright 2018-2023 XGBoost contributors */ #include #include #include "../../src/common/common.h" #include "../helpers.h" +#include "test_multiclass_obj.h" namespace xgboost { -TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassObjGPair)) { - Context ctx = MakeCUDACtx(GPUIDX); +void TestSoftmaxMultiClassObjGPair(const Context* ctx) { std::vector> args {{"num_class", "3"}}; std::unique_ptr obj { - ObjFunction::Create("multi:softmax", &ctx) + ObjFunction::Create("multi:softmax", ctx) }; obj->Configure(args); @@ -35,12 +35,11 @@ TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassObjGPair)) { ASSERT_NO_THROW(obj->DefaultEvalMetric()); } -TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassBasic)) { - auto ctx = MakeCUDACtx(GPUIDX); +void TestSoftmaxMultiClassBasic(const Context* ctx) { std::vector> args{ std::pair("num_class", "3")}; - std::unique_ptr obj{ObjFunction::Create("multi:softmax", &ctx)}; + std::unique_ptr obj{ObjFunction::Create("multi:softmax", ctx)}; obj->Configure(args); CheckConfigReload(obj, "multi:softmax"); @@ -56,13 +55,12 @@ TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassBasic)) { } } -TEST(Objective, DeclareUnifiedTest(SoftprobMultiClassBasic)) { - Context ctx = MakeCUDACtx(GPUIDX); +void TestSoftprobMultiClassBasic(const Context* ctx) { std::vector> args { std::pair("num_class", "3")}; std::unique_ptr obj { - ObjFunction::Create("multi:softprob", &ctx) + ObjFunction::Create("multi:softprob", ctx) }; obj->Configure(args); CheckConfigReload(obj, "multi:softprob"); @@ -77,4 +75,5 @@ TEST(Objective, DeclareUnifiedTest(SoftprobMultiClassBasic)) { EXPECT_NEAR(preds[i], out_preds[i], 0.01f); } } + } // namespace xgboost diff --git a/tests/cpp/objective/test_multiclass_obj.h b/tests/cpp/objective/test_multiclass_obj.h new file mode 100644 index 000000000..bf6f9258c --- /dev/null +++ b/tests/cpp/objective/test_multiclass_obj.h @@ -0,0 +1,19 @@ +/** + * Copyright 2020-2023 by XGBoost Contributors + */ +#ifndef XGBOOST_TEST_MULTICLASS_OBJ_H_ +#define XGBOOST_TEST_MULTICLASS_OBJ_H_ + +#include // for Context + +namespace xgboost { + +void TestSoftmaxMultiClassObjGPair(const Context* ctx); + +void TestSoftmaxMultiClassBasic(const Context* ctx); + +void TestSoftprobMultiClassBasic(const Context* ctx); + +} // namespace xgboost + +#endif // XGBOOST_TEST_MULTICLASS_OBJ_H_ diff --git a/tests/cpp/objective/test_multiclass_obj_cpu.cc b/tests/cpp/objective/test_multiclass_obj_cpu.cc new file mode 100644 index 000000000..d3cb8aa1f --- /dev/null +++ b/tests/cpp/objective/test_multiclass_obj_cpu.cc @@ -0,0 +1,25 @@ +/*! + * Copyright 2018-2023 XGBoost contributors + */ +#include +#include + +#include "../helpers.h" +#include "test_multiclass_obj.h" + +namespace xgboost { +TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassObjGPair)) { + Context ctx = MakeCUDACtx(GPUIDX); + TestSoftmaxMultiClassObjGPair(&ctx); +} + +TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassBasic)) { + auto ctx = MakeCUDACtx(GPUIDX); + TestSoftmaxMultiClassBasic(&ctx); +} + +TEST(Objective, DeclareUnifiedTest(SoftprobMultiClassBasic)) { + Context ctx = MakeCUDACtx(GPUIDX); + TestSoftprobMultiClassBasic(&ctx); +} +} // namespace xgboost diff --git a/tests/cpp/objective/test_multiclass_obj_gpu.cu b/tests/cpp/objective/test_multiclass_obj_gpu.cu index 7567d3242..f80f07ce8 100644 --- a/tests/cpp/objective/test_multiclass_obj_gpu.cu +++ b/tests/cpp/objective/test_multiclass_obj_gpu.cu @@ -1 +1 @@ -#include "test_multiclass_obj.cc" +#include "test_multiclass_obj_cpu.cc" diff --git a/tests/cpp/objective/test_regression_obj.cc b/tests/cpp/objective/test_regression_obj.cc index 35e8287b6..4bd693936 100644 --- a/tests/cpp/objective/test_regression_obj.cc +++ b/tests/cpp/objective/test_regression_obj.cc @@ -14,13 +14,15 @@ #include "xgboost/data.h" #include "xgboost/linalg.h" +#include "test_regression_obj.h" + namespace xgboost { -TEST(Objective, DeclareUnifiedTest(LinearRegressionGPair)) { - Context ctx = MakeCUDACtx(GPUIDX); - std::vector> args; +void TestLinearRegressionGPair(const Context* ctx) { + std::string obj_name = "reg:squarederror"; - std::unique_ptr obj{ObjFunction::Create("reg:squarederror", &ctx)}; + std::vector> args; + std::unique_ptr obj{ObjFunction::Create(obj_name, ctx)}; obj->Configure(args); CheckObjFunction(obj, @@ -38,13 +40,13 @@ TEST(Objective, DeclareUnifiedTest(LinearRegressionGPair)) { ASSERT_NO_THROW(obj->DefaultEvalMetric()); } -TEST(Objective, DeclareUnifiedTest(SquaredLog)) { - Context ctx = MakeCUDACtx(GPUIDX); +void TestSquaredLog(const Context* ctx) { + std::string obj_name = "reg:squaredlogerror"; std::vector> args; - std::unique_ptr obj{ObjFunction::Create("reg:squaredlogerror", &ctx)}; + std::unique_ptr obj{ObjFunction::Create(obj_name, ctx)}; obj->Configure(args); - CheckConfigReload(obj, "reg:squaredlogerror"); + CheckConfigReload(obj, obj_name); CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f}, // pred @@ -61,42 +63,13 @@ TEST(Objective, DeclareUnifiedTest(SquaredLog)) { ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"rmsle"}); } -TEST(Objective, DeclareUnifiedTest(PseudoHuber)) { - Context ctx = MakeCUDACtx(GPUIDX); - Args args; - - std::unique_ptr obj{ObjFunction::Create("reg:pseudohubererror", &ctx)}; - obj->Configure(args); - CheckConfigReload(obj, "reg:pseudohubererror"); - - CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f}, // pred - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // labels - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // weights - {-0.668965f, -0.624695f, -0.514496f, -0.196116f, 0.514496f}, // out_grad - {0.410660f, 0.476140f, 0.630510f, 0.9428660f, 0.630510f}); // out_hess - CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f}, // pred - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // labels - {}, // empty weights - {-0.668965f, -0.624695f, -0.514496f, -0.196116f, 0.514496f}, // out_grad - {0.410660f, 0.476140f, 0.630510f, 0.9428660f, 0.630510f}); // out_hess - ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"mphe"}); - - obj->Configure({{"huber_slope", "0.1"}}); - CheckConfigReload(obj, "reg:pseudohubererror"); - CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f}, // pred - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // labels - {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // weights - {-0.099388f, -0.099228f, -0.098639f, -0.089443f, 0.098639f}, // out_grad - {0.0013467f, 0.001908f, 0.004443f, 0.089443f, 0.004443f}); // out_hess -} - -TEST(Objective, DeclareUnifiedTest(LogisticRegressionGPair)) { - Context ctx = MakeCUDACtx(GPUIDX); +void TestLogisticRegressionGPair(const Context* ctx) { + std::string obj_name = "reg:logistic"; std::vector> args; - std::unique_ptr obj{ObjFunction::Create("reg:logistic", &ctx)}; + std::unique_ptr obj{ObjFunction::Create(obj_name, ctx)}; obj->Configure(args); - CheckConfigReload(obj, "reg:logistic"); + CheckConfigReload(obj, obj_name); CheckObjFunction(obj, { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, // preds @@ -106,13 +79,13 @@ TEST(Objective, DeclareUnifiedTest(LogisticRegressionGPair)) { {0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f}); // out_hess } -TEST(Objective, DeclareUnifiedTest(LogisticRegressionBasic)) { - Context ctx = MakeCUDACtx(GPUIDX); +void TestLogisticRegressionBasic(const Context* ctx) { + std::string obj_name = "reg:logistic"; std::vector> args; - std::unique_ptr obj{ObjFunction::Create("reg:logistic", &ctx)}; + std::unique_ptr obj{ObjFunction::Create(obj_name, ctx)}; obj->Configure(args); - CheckConfigReload(obj, "reg:logistic"); + CheckConfigReload(obj, obj_name); // test label validation EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {10}, {1}, {0}, {0})) @@ -135,12 +108,10 @@ TEST(Objective, DeclareUnifiedTest(LogisticRegressionBasic)) { } } -TEST(Objective, DeclareUnifiedTest(LogisticRawGPair)) { - Context ctx = MakeCUDACtx(GPUIDX); +void TestsLogisticRawGPair(const Context* ctx) { + std::string obj_name = "binary:logitraw"; std::vector> args; - std::unique_ptr obj { - ObjFunction::Create("binary:logitraw", &ctx) - }; + std::unique_ptr obj {ObjFunction::Create(obj_name, ctx)}; obj->Configure(args); CheckObjFunction(obj, @@ -151,347 +122,4 @@ TEST(Objective, DeclareUnifiedTest(LogisticRawGPair)) { {0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f}); } -TEST(Objective, DeclareUnifiedTest(PoissonRegressionGPair)) { - Context ctx = MakeCUDACtx(GPUIDX); - std::vector> args; - std::unique_ptr obj { - ObjFunction::Create("count:poisson", &ctx) - }; - - args.emplace_back("max_delta_step", "0.1f"); - obj->Configure(args); - - CheckObjFunction(obj, - { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, - { 0, 0, 0, 0, 1, 1, 1, 1}, - { 1, 1, 1, 1, 1, 1, 1, 1}, - { 1, 1.10f, 2.45f, 2.71f, 0, 0.10f, 1.45f, 1.71f}, - {1.10f, 1.22f, 2.71f, 3.00f, 1.10f, 1.22f, 2.71f, 3.00f}); - CheckObjFunction(obj, - { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, - { 0, 0, 0, 0, 1, 1, 1, 1}, - {}, // Empty weight - { 1, 1.10f, 2.45f, 2.71f, 0, 0.10f, 1.45f, 1.71f}, - {1.10f, 1.22f, 2.71f, 3.00f, 1.10f, 1.22f, 2.71f, 3.00f}); -} - -TEST(Objective, DeclareUnifiedTest(PoissonRegressionBasic)) { - Context ctx = MakeCUDACtx(GPUIDX); - std::vector> args; - std::unique_ptr obj { - ObjFunction::Create("count:poisson", &ctx) - }; - - obj->Configure(args); - CheckConfigReload(obj, "count:poisson"); - - // test label validation - EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {-1}, {1}, {0}, {0})) - << "Expected error when label < 0 for PoissonRegression"; - - // test ProbToMargin - EXPECT_NEAR(obj->ProbToMargin(0.1f), -2.30f, 0.01f); - EXPECT_NEAR(obj->ProbToMargin(0.5f), -0.69f, 0.01f); - EXPECT_NEAR(obj->ProbToMargin(0.9f), -0.10f, 0.01f); - - // test PredTransform - HostDeviceVector io_preds = {0, 0.1f, 0.5f, 0.9f, 1}; - std::vector out_preds = {1, 1.10f, 1.64f, 2.45f, 2.71f}; - obj->PredTransform(&io_preds); - auto& preds = io_preds.HostVector(); - for (int i = 0; i < static_cast(io_preds.Size()); ++i) { - EXPECT_NEAR(preds[i], out_preds[i], 0.01f); - } -} - -TEST(Objective, DeclareUnifiedTest(GammaRegressionGPair)) { - Context ctx = MakeCUDACtx(GPUIDX); - std::vector> args; - std::unique_ptr obj { - ObjFunction::Create("reg:gamma", &ctx) - }; - - obj->Configure(args); - CheckObjFunction(obj, - {0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, - {2, 2, 2, 2, 1, 1, 1, 1}, - {1, 1, 1, 1, 1, 1, 1, 1}, - {-1, -0.809, 0.187, 0.264, 0, 0.09f, 0.59f, 0.63f}, - {2, 1.809, 0.813, 0.735, 1, 0.90f, 0.40f, 0.36f}); - CheckObjFunction(obj, - {0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, - {2, 2, 2, 2, 1, 1, 1, 1}, - {}, // Empty weight - {-1, -0.809, 0.187, 0.264, 0, 0.09f, 0.59f, 0.63f}, - {2, 1.809, 0.813, 0.735, 1, 0.90f, 0.40f, 0.36f}); -} - -TEST(Objective, DeclareUnifiedTest(GammaRegressionBasic)) { - Context ctx = MakeCUDACtx(GPUIDX); - std::vector> args; - std::unique_ptr obj{ObjFunction::Create("reg:gamma", &ctx)}; - - obj->Configure(args); - CheckConfigReload(obj, "reg:gamma"); - - // test label validation - EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {0}, {1}, {0}, {0})) - << "Expected error when label = 0 for GammaRegression"; - EXPECT_ANY_THROW(CheckObjFunction(obj, {-1}, {-1}, {1}, {-1}, {-3})) - << "Expected error when label < 0 for GammaRegression"; - - // test ProbToMargin - EXPECT_NEAR(obj->ProbToMargin(0.1f), -2.30f, 0.01f); - EXPECT_NEAR(obj->ProbToMargin(0.5f), -0.69f, 0.01f); - EXPECT_NEAR(obj->ProbToMargin(0.9f), -0.10f, 0.01f); - - // test PredTransform - HostDeviceVector io_preds = {0, 0.1f, 0.5f, 0.9f, 1}; - std::vector out_preds = {1, 1.10f, 1.64f, 2.45f, 2.71f}; - obj->PredTransform(&io_preds); - auto& preds = io_preds.HostVector(); - for (int i = 0; i < static_cast(io_preds.Size()); ++i) { - EXPECT_NEAR(preds[i], out_preds[i], 0.01f); - } -} - -TEST(Objective, DeclareUnifiedTest(TweedieRegressionGPair)) { - Context ctx = MakeCUDACtx(GPUIDX); - std::vector> args; - std::unique_ptr obj{ObjFunction::Create("reg:tweedie", &ctx)}; - - args.emplace_back("tweedie_variance_power", "1.1f"); - obj->Configure(args); - - CheckObjFunction(obj, - { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, - { 0, 0, 0, 0, 1, 1, 1, 1}, - { 1, 1, 1, 1, 1, 1, 1, 1}, - { 1, 1.09f, 2.24f, 2.45f, 0, 0.10f, 1.33f, 1.55f}, - {0.89f, 0.98f, 2.02f, 2.21f, 1, 1.08f, 2.11f, 2.30f}); - CheckObjFunction(obj, - { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, - { 0, 0, 0, 0, 1, 1, 1, 1}, - {}, // Empty weight. - { 1, 1.09f, 2.24f, 2.45f, 0, 0.10f, 1.33f, 1.55f}, - {0.89f, 0.98f, 2.02f, 2.21f, 1, 1.08f, 2.11f, 2.30f}); - ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"tweedie-nloglik@1.1"}); -} - -#if defined(__CUDACC__) -TEST(Objective, CPU_vs_CUDA) { - Context ctx = MakeCUDACtx(GPUIDX); - - std::unique_ptr obj{ObjFunction::Create("reg:squarederror", &ctx)}; - linalg::Matrix cpu_out_preds; - linalg::Matrix cuda_out_preds; - - constexpr size_t kRows = 400; - constexpr size_t kCols = 100; - auto pdmat = RandomDataGenerator(kRows, kCols, 0).Seed(0).GenerateDMatrix(); - HostDeviceVector preds; - preds.Resize(kRows); - auto& h_preds = preds.HostVector(); - for (size_t i = 0; i < h_preds.size(); ++i) { - h_preds[i] = static_cast(i); - } - auto& info = pdmat->Info(); - - info.labels.Reshape(kRows); - auto& h_labels = info.labels.Data()->HostVector(); - for (size_t i = 0; i < h_labels.size(); ++i) { - h_labels[i] = 1 / static_cast(i+1); - } - - { - // CPU - ctx = ctx.MakeCPU(); - obj->GetGradient(preds, info, 0, &cpu_out_preds); - } - { - // CUDA - ctx = ctx.MakeCUDA(0); - obj->GetGradient(preds, info, 0, &cuda_out_preds); - } - - auto h_cpu_out = cpu_out_preds.HostView(); - auto h_cuda_out = cuda_out_preds.HostView(); - - float sgrad = 0; - float shess = 0; - for (size_t i = 0; i < kRows; ++i) { - sgrad += std::pow(h_cpu_out(i).GetGrad() - h_cuda_out(i).GetGrad(), 2); - shess += std::pow(h_cpu_out(i).GetHess() - h_cuda_out(i).GetHess(), 2); - } - ASSERT_NEAR(sgrad, 0.0f, kRtEps); - ASSERT_NEAR(shess, 0.0f, kRtEps); -} -#endif - -TEST(Objective, DeclareUnifiedTest(TweedieRegressionBasic)) { - Context ctx = MakeCUDACtx(GPUIDX); - std::vector> args; - std::unique_ptr obj{ObjFunction::Create("reg:tweedie", &ctx)}; - - obj->Configure(args); - CheckConfigReload(obj, "reg:tweedie"); - - // test label validation - EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {-1}, {1}, {0}, {0})) - << "Expected error when label < 0 for TweedieRegression"; - - // test ProbToMargin - EXPECT_NEAR(obj->ProbToMargin(0.1f), -2.30f, 0.01f); - EXPECT_NEAR(obj->ProbToMargin(0.5f), -0.69f, 0.01f); - EXPECT_NEAR(obj->ProbToMargin(0.9f), -0.10f, 0.01f); - - // test PredTransform - HostDeviceVector io_preds = {0, 0.1f, 0.5f, 0.9f, 1}; - std::vector out_preds = {1, 1.10f, 1.64f, 2.45f, 2.71f}; - obj->PredTransform(&io_preds); - auto& preds = io_preds.HostVector(); - for (int i = 0; i < static_cast(io_preds.Size()); ++i) { - EXPECT_NEAR(preds[i], out_preds[i], 0.01f); - } -} - -// CoxRegression not implemented in GPU code, no need for testing. -#if !defined(__CUDACC__) -TEST(Objective, CoxRegressionGPair) { - Context ctx = MakeCUDACtx(GPUIDX); - std::vector> args; - std::unique_ptr obj{ObjFunction::Create("survival:cox", &ctx)}; - - obj->Configure(args); - CheckObjFunction(obj, - { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, - { 0, -2, -2, 2, 3, 5, -10, 100}, - { 1, 1, 1, 1, 1, 1, 1, 1}, - { 0, 0, 0, -0.799f, -0.788f, -0.590f, 0.910f, 1.006f}, - { 0, 0, 0, 0.160f, 0.186f, 0.348f, 0.610f, 0.639f}); -} -#endif - -TEST(Objective, DeclareUnifiedTest(AbsoluteError)) { - Context ctx = MakeCUDACtx(GPUIDX); - std::unique_ptr obj{ObjFunction::Create("reg:absoluteerror", &ctx)}; - obj->Configure({}); - CheckConfigReload(obj, "reg:absoluteerror"); - - MetaInfo info; - std::vector labels{0.f, 3.f, 2.f, 5.f, 4.f, 7.f}; - info.labels.Reshape(6, 1); - info.labels.Data()->HostVector() = labels; - info.num_row_ = labels.size(); - HostDeviceVector predt{1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; - info.weights_.HostVector() = {1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; - - CheckObjFunction(obj, predt.HostVector(), labels, info.weights_.HostVector(), - {1.f, -1.f, 1.f, -1.f, 1.f, -1.f}, info.weights_.HostVector()); - - RegTree tree; - tree.ExpandNode(0, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f); - - HostDeviceVector position(labels.size(), 0); - auto& h_position = position.HostVector(); - for (size_t i = 0; i < labels.size(); ++i) { - if (i < labels.size() / 2) { - h_position[i] = 1; // left - } else { - h_position[i] = 2; // right - } - } - - auto& h_predt = predt.HostVector(); - for (size_t i = 0; i < h_predt.size(); ++i) { - h_predt[i] = labels[i] + i; - } - - tree::TrainParam param; - param.Init(Args{}); - auto lr = param.learning_rate; - - obj->UpdateTreeLeaf(position, info, param.learning_rate, predt, 0, &tree); - ASSERT_EQ(tree[1].LeafValue(), -1.0f * lr); - ASSERT_EQ(tree[2].LeafValue(), -4.0f * lr); -} - -TEST(Objective, DeclareUnifiedTest(AbsoluteErrorLeaf)) { - Context ctx = MakeCUDACtx(GPUIDX); - bst_target_t constexpr kTargets = 3, kRows = 16; - std::unique_ptr obj{ObjFunction::Create("reg:absoluteerror", &ctx)}; - obj->Configure({}); - - MetaInfo info; - info.num_row_ = kRows; - info.labels.Reshape(16, kTargets); - HostDeviceVector predt(info.labels.Size()); - - for (bst_target_t t{0}; t < kTargets; ++t) { - auto h_labels = info.labels.HostView().Slice(linalg::All(), t); - std::iota(linalg::begin(h_labels), linalg::end(h_labels), 0); - - auto h_predt = - linalg::MakeTensorView(&ctx, predt.HostSpan(), kRows, kTargets).Slice(linalg::All(), t); - for (size_t i = 0; i < h_predt.Size(); ++i) { - h_predt(i) = h_labels(i) + i; - } - - HostDeviceVector position(h_labels.Size(), 0); - auto& h_position = position.HostVector(); - for (int32_t i = 0; i < 3; ++i) { - h_position[i] = ~i; // negation for sampled nodes. - } - for (size_t i = 3; i < 8; ++i) { - h_position[i] = 3; - } - // empty leaf for node 4 - for (size_t i = 8; i < 13; ++i) { - h_position[i] = 5; - } - for (size_t i = 13; i < h_labels.Size(); ++i) { - h_position[i] = 6; - } - - RegTree tree; - tree.ExpandNode(0, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f); - tree.ExpandNode(1, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f); - tree.ExpandNode(2, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f); - ASSERT_EQ(tree.GetNumLeaves(), 4); - - auto empty_leaf = tree[4].LeafValue(); - - tree::TrainParam param; - param.Init(Args{}); - auto lr = param.learning_rate; - - obj->UpdateTreeLeaf(position, info, lr, predt, t, &tree); - ASSERT_EQ(tree[3].LeafValue(), -5.0f * lr); - ASSERT_EQ(tree[4].LeafValue(), empty_leaf * lr); - ASSERT_EQ(tree[5].LeafValue(), -10.0f * lr); - ASSERT_EQ(tree[6].LeafValue(), -14.0f * lr); - } -} - -TEST(Adaptive, DeclareUnifiedTest(MissingLeaf)) { - std::vector missing{1, 3}; - - std::vector h_nidx = {2, 4, 5}; - std::vector h_nptr = {0, 4, 8, 16}; - - obj::detail::FillMissingLeaf(missing, &h_nidx, &h_nptr); - - ASSERT_EQ(h_nidx[0], missing[0]); - ASSERT_EQ(h_nidx[2], missing[1]); - ASSERT_EQ(h_nidx[1], 2); - ASSERT_EQ(h_nidx[3], 4); - ASSERT_EQ(h_nidx[4], 5); - - ASSERT_EQ(h_nptr[0], 0); - ASSERT_EQ(h_nptr[1], 0); // empty - ASSERT_EQ(h_nptr[2], 4); - ASSERT_EQ(h_nptr[3], 4); // empty - ASSERT_EQ(h_nptr[4], 8); - ASSERT_EQ(h_nptr[5], 16); -} } // namespace xgboost diff --git a/tests/cpp/objective/test_regression_obj.h b/tests/cpp/objective/test_regression_obj.h new file mode 100644 index 000000000..41f7c370e --- /dev/null +++ b/tests/cpp/objective/test_regression_obj.h @@ -0,0 +1,23 @@ +/** + * Copyright 2020-2023 by XGBoost Contributors + */ +#ifndef XGBOOST_TEST_REGRESSION_OBJ_H_ +#define XGBOOST_TEST_REGRESSION_OBJ_H_ + +#include // for Context + +namespace xgboost { + +void TestLinearRegressionGPair(const Context* ctx); + +void TestSquaredLog(const Context* ctx); + +void TestLogisticRegressionGPair(const Context* ctx); + +void TestLogisticRegressionBasic(const Context* ctx); + +void TestsLogisticRawGPair(const Context* ctx); + +} // namespace xgboost + +#endif // XGBOOST_TEST_REGRESSION_OBJ_H_ diff --git a/tests/cpp/objective/test_regression_obj_cpu.cc b/tests/cpp/objective/test_regression_obj_cpu.cc new file mode 100644 index 000000000..3613d0d90 --- /dev/null +++ b/tests/cpp/objective/test_regression_obj_cpu.cc @@ -0,0 +1,412 @@ +/*! + * Copyright 2018-2023 XGBoost contributors + */ +#include +#include +#include + +#include "../../../src/objective/adaptive.h" +#include "../../../src/tree/param.h" // for TrainParam +#include "../helpers.h" + +#include "test_regression_obj.h" + +namespace xgboost { +TEST(Objective, DeclareUnifiedTest(LinearRegressionGPair)) { + Context ctx = MakeCUDACtx(GPUIDX); + TestLinearRegressionGPair(&ctx); +} + +TEST(Objective, DeclareUnifiedTest(SquaredLog)) { + Context ctx = MakeCUDACtx(GPUIDX); + TestSquaredLog(&ctx); +} + +TEST(Objective, DeclareUnifiedTest(PseudoHuber)) { + Context ctx = MakeCUDACtx(GPUIDX); + Args args; + + std::unique_ptr obj{ObjFunction::Create("reg:pseudohubererror", &ctx)}; + obj->Configure(args); + CheckConfigReload(obj, "reg:pseudohubererror"); + + CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f}, // pred + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // labels + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // weights + {-0.668965f, -0.624695f, -0.514496f, -0.196116f, 0.514496f}, // out_grad + {0.410660f, 0.476140f, 0.630510f, 0.9428660f, 0.630510f}); // out_hess + CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f}, // pred + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // labels + {}, // empty weights + {-0.668965f, -0.624695f, -0.514496f, -0.196116f, 0.514496f}, // out_grad + {0.410660f, 0.476140f, 0.630510f, 0.9428660f, 0.630510f}); // out_hess + ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"mphe"}); + + obj->Configure({{"huber_slope", "0.1"}}); + CheckConfigReload(obj, "reg:pseudohubererror"); + CheckObjFunction(obj, {0.1f, 0.2f, 0.4f, 0.8f, 1.6f}, // pred + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // labels + {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, // weights + {-0.099388f, -0.099228f, -0.098639f, -0.089443f, 0.098639f}, // out_grad + {0.0013467f, 0.001908f, 0.004443f, 0.089443f, 0.004443f}); // out_hess +} + +TEST(Objective, DeclareUnifiedTest(LogisticRegressionGPair)) { + Context ctx = MakeCUDACtx(GPUIDX); + TestLogisticRegressionGPair(&ctx); +} + +TEST(Objective, DeclareUnifiedTest(LogisticRegressionBasic)) { + Context ctx = MakeCUDACtx(GPUIDX); + TestLogisticRegressionBasic(&ctx); +} + +TEST(Objective, DeclareUnifiedTest(LogisticRawGPair)) { + Context ctx = MakeCUDACtx(GPUIDX); + TestsLogisticRawGPair(&ctx); +} + +TEST(Objective, DeclareUnifiedTest(PoissonRegressionGPair)) { + Context ctx = MakeCUDACtx(GPUIDX); + std::vector> args; + std::unique_ptr obj { + ObjFunction::Create("count:poisson", &ctx) + }; + + args.emplace_back("max_delta_step", "0.1f"); + obj->Configure(args); + + CheckObjFunction(obj, + { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, + { 0, 0, 0, 0, 1, 1, 1, 1}, + { 1, 1, 1, 1, 1, 1, 1, 1}, + { 1, 1.10f, 2.45f, 2.71f, 0, 0.10f, 1.45f, 1.71f}, + {1.10f, 1.22f, 2.71f, 3.00f, 1.10f, 1.22f, 2.71f, 3.00f}); + CheckObjFunction(obj, + { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, + { 0, 0, 0, 0, 1, 1, 1, 1}, + {}, // Empty weight + { 1, 1.10f, 2.45f, 2.71f, 0, 0.10f, 1.45f, 1.71f}, + {1.10f, 1.22f, 2.71f, 3.00f, 1.10f, 1.22f, 2.71f, 3.00f}); +} + +TEST(Objective, DeclareUnifiedTest(PoissonRegressionBasic)) { + Context ctx = MakeCUDACtx(GPUIDX); + std::vector> args; + std::unique_ptr obj { + ObjFunction::Create("count:poisson", &ctx) + }; + + obj->Configure(args); + CheckConfigReload(obj, "count:poisson"); + + // test label validation + EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {-1}, {1}, {0}, {0})) + << "Expected error when label < 0 for PoissonRegression"; + + // test ProbToMargin + EXPECT_NEAR(obj->ProbToMargin(0.1f), -2.30f, 0.01f); + EXPECT_NEAR(obj->ProbToMargin(0.5f), -0.69f, 0.01f); + EXPECT_NEAR(obj->ProbToMargin(0.9f), -0.10f, 0.01f); + + // test PredTransform + HostDeviceVector io_preds = {0, 0.1f, 0.5f, 0.9f, 1}; + std::vector out_preds = {1, 1.10f, 1.64f, 2.45f, 2.71f}; + obj->PredTransform(&io_preds); + auto& preds = io_preds.HostVector(); + for (int i = 0; i < static_cast(io_preds.Size()); ++i) { + EXPECT_NEAR(preds[i], out_preds[i], 0.01f); + } +} + +TEST(Objective, DeclareUnifiedTest(GammaRegressionGPair)) { + Context ctx = MakeCUDACtx(GPUIDX); + std::vector> args; + std::unique_ptr obj { + ObjFunction::Create("reg:gamma", &ctx) + }; + + obj->Configure(args); + CheckObjFunction(obj, + {0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, + {2, 2, 2, 2, 1, 1, 1, 1}, + {1, 1, 1, 1, 1, 1, 1, 1}, + {-1, -0.809, 0.187, 0.264, 0, 0.09f, 0.59f, 0.63f}, + {2, 1.809, 0.813, 0.735, 1, 0.90f, 0.40f, 0.36f}); + CheckObjFunction(obj, + {0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, + {2, 2, 2, 2, 1, 1, 1, 1}, + {}, // Empty weight + {-1, -0.809, 0.187, 0.264, 0, 0.09f, 0.59f, 0.63f}, + {2, 1.809, 0.813, 0.735, 1, 0.90f, 0.40f, 0.36f}); +} + +TEST(Objective, DeclareUnifiedTest(GammaRegressionBasic)) { + Context ctx = MakeCUDACtx(GPUIDX); + std::vector> args; + std::unique_ptr obj{ObjFunction::Create("reg:gamma", &ctx)}; + + obj->Configure(args); + CheckConfigReload(obj, "reg:gamma"); + + // test label validation + EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {0}, {1}, {0}, {0})) + << "Expected error when label = 0 for GammaRegression"; + EXPECT_ANY_THROW(CheckObjFunction(obj, {-1}, {-1}, {1}, {-1}, {-3})) + << "Expected error when label < 0 for GammaRegression"; + + // test ProbToMargin + EXPECT_NEAR(obj->ProbToMargin(0.1f), -2.30f, 0.01f); + EXPECT_NEAR(obj->ProbToMargin(0.5f), -0.69f, 0.01f); + EXPECT_NEAR(obj->ProbToMargin(0.9f), -0.10f, 0.01f); + + // test PredTransform + HostDeviceVector io_preds = {0, 0.1f, 0.5f, 0.9f, 1}; + std::vector out_preds = {1, 1.10f, 1.64f, 2.45f, 2.71f}; + obj->PredTransform(&io_preds); + auto& preds = io_preds.HostVector(); + for (int i = 0; i < static_cast(io_preds.Size()); ++i) { + EXPECT_NEAR(preds[i], out_preds[i], 0.01f); + } +} + +TEST(Objective, DeclareUnifiedTest(TweedieRegressionGPair)) { + Context ctx = MakeCUDACtx(GPUIDX); + std::vector> args; + std::unique_ptr obj{ObjFunction::Create("reg:tweedie", &ctx)}; + + args.emplace_back("tweedie_variance_power", "1.1f"); + obj->Configure(args); + + CheckObjFunction(obj, + { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, + { 0, 0, 0, 0, 1, 1, 1, 1}, + { 1, 1, 1, 1, 1, 1, 1, 1}, + { 1, 1.09f, 2.24f, 2.45f, 0, 0.10f, 1.33f, 1.55f}, + {0.89f, 0.98f, 2.02f, 2.21f, 1, 1.08f, 2.11f, 2.30f}); + CheckObjFunction(obj, + { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, + { 0, 0, 0, 0, 1, 1, 1, 1}, + {}, // Empty weight. + { 1, 1.09f, 2.24f, 2.45f, 0, 0.10f, 1.33f, 1.55f}, + {0.89f, 0.98f, 2.02f, 2.21f, 1, 1.08f, 2.11f, 2.30f}); + ASSERT_EQ(obj->DefaultEvalMetric(), std::string{"tweedie-nloglik@1.1"}); +} + +#if defined(__CUDACC__) +TEST(Objective, CPU_vs_CUDA) { + Context ctx = MakeCUDACtx(GPUIDX); + + std::unique_ptr obj{ObjFunction::Create("reg:squarederror", &ctx)}; + linalg::Matrix cpu_out_preds; + linalg::Matrix cuda_out_preds; + + constexpr size_t kRows = 400; + constexpr size_t kCols = 100; + auto pdmat = RandomDataGenerator(kRows, kCols, 0).Seed(0).GenerateDMatrix(); + HostDeviceVector preds; + preds.Resize(kRows); + auto& h_preds = preds.HostVector(); + for (size_t i = 0; i < h_preds.size(); ++i) { + h_preds[i] = static_cast(i); + } + auto& info = pdmat->Info(); + + info.labels.Reshape(kRows); + auto& h_labels = info.labels.Data()->HostVector(); + for (size_t i = 0; i < h_labels.size(); ++i) { + h_labels[i] = 1 / static_cast(i+1); + } + + { + // CPU + ctx = ctx.MakeCPU(); + obj->GetGradient(preds, info, 0, &cpu_out_preds); + } + { + // CUDA + ctx = ctx.MakeCUDA(0); + obj->GetGradient(preds, info, 0, &cuda_out_preds); + } + + auto h_cpu_out = cpu_out_preds.HostView(); + auto h_cuda_out = cuda_out_preds.HostView(); + + float sgrad = 0; + float shess = 0; + for (size_t i = 0; i < kRows; ++i) { + sgrad += std::pow(h_cpu_out(i).GetGrad() - h_cuda_out(i).GetGrad(), 2); + shess += std::pow(h_cpu_out(i).GetHess() - h_cuda_out(i).GetHess(), 2); + } + ASSERT_NEAR(sgrad, 0.0f, kRtEps); + ASSERT_NEAR(shess, 0.0f, kRtEps); +} +#endif + +TEST(Objective, DeclareUnifiedTest(TweedieRegressionBasic)) { + Context ctx = MakeCUDACtx(GPUIDX); + std::vector> args; + std::unique_ptr obj{ObjFunction::Create("reg:tweedie", &ctx)}; + + obj->Configure(args); + CheckConfigReload(obj, "reg:tweedie"); + + // test label validation + EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {-1}, {1}, {0}, {0})) + << "Expected error when label < 0 for TweedieRegression"; + + // test ProbToMargin + EXPECT_NEAR(obj->ProbToMargin(0.1f), -2.30f, 0.01f); + EXPECT_NEAR(obj->ProbToMargin(0.5f), -0.69f, 0.01f); + EXPECT_NEAR(obj->ProbToMargin(0.9f), -0.10f, 0.01f); + + // test PredTransform + HostDeviceVector io_preds = {0, 0.1f, 0.5f, 0.9f, 1}; + std::vector out_preds = {1, 1.10f, 1.64f, 2.45f, 2.71f}; + obj->PredTransform(&io_preds); + auto& preds = io_preds.HostVector(); + for (int i = 0; i < static_cast(io_preds.Size()); ++i) { + EXPECT_NEAR(preds[i], out_preds[i], 0.01f); + } +} + +// CoxRegression not implemented in GPU code, no need for testing. +#if !defined(__CUDACC__) +TEST(Objective, CoxRegressionGPair) { + Context ctx = MakeCUDACtx(GPUIDX); + std::vector> args; + std::unique_ptr obj{ObjFunction::Create("survival:cox", &ctx)}; + + obj->Configure(args); + CheckObjFunction(obj, + { 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, + { 0, -2, -2, 2, 3, 5, -10, 100}, + { 1, 1, 1, 1, 1, 1, 1, 1}, + { 0, 0, 0, -0.799f, -0.788f, -0.590f, 0.910f, 1.006f}, + { 0, 0, 0, 0.160f, 0.186f, 0.348f, 0.610f, 0.639f}); +} +#endif + +TEST(Objective, DeclareUnifiedTest(AbsoluteError)) { + Context ctx = MakeCUDACtx(GPUIDX); + std::unique_ptr obj{ObjFunction::Create("reg:absoluteerror", &ctx)}; + obj->Configure({}); + CheckConfigReload(obj, "reg:absoluteerror"); + + MetaInfo info; + std::vector labels{0.f, 3.f, 2.f, 5.f, 4.f, 7.f}; + info.labels.Reshape(6, 1); + info.labels.Data()->HostVector() = labels; + info.num_row_ = labels.size(); + HostDeviceVector predt{1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; + info.weights_.HostVector() = {1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; + + CheckObjFunction(obj, predt.HostVector(), labels, info.weights_.HostVector(), + {1.f, -1.f, 1.f, -1.f, 1.f, -1.f}, info.weights_.HostVector()); + + RegTree tree; + tree.ExpandNode(0, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f); + + HostDeviceVector position(labels.size(), 0); + auto& h_position = position.HostVector(); + for (size_t i = 0; i < labels.size(); ++i) { + if (i < labels.size() / 2) { + h_position[i] = 1; // left + } else { + h_position[i] = 2; // right + } + } + + auto& h_predt = predt.HostVector(); + for (size_t i = 0; i < h_predt.size(); ++i) { + h_predt[i] = labels[i] + i; + } + + tree::TrainParam param; + param.Init(Args{}); + auto lr = param.learning_rate; + + obj->UpdateTreeLeaf(position, info, param.learning_rate, predt, 0, &tree); + ASSERT_EQ(tree[1].LeafValue(), -1.0f * lr); + ASSERT_EQ(tree[2].LeafValue(), -4.0f * lr); +} + +TEST(Objective, DeclareUnifiedTest(AbsoluteErrorLeaf)) { + Context ctx = MakeCUDACtx(GPUIDX); + bst_target_t constexpr kTargets = 3, kRows = 16; + std::unique_ptr obj{ObjFunction::Create("reg:absoluteerror", &ctx)}; + obj->Configure({}); + + MetaInfo info; + info.num_row_ = kRows; + info.labels.Reshape(16, kTargets); + HostDeviceVector predt(info.labels.Size()); + + for (bst_target_t t{0}; t < kTargets; ++t) { + auto h_labels = info.labels.HostView().Slice(linalg::All(), t); + std::iota(linalg::begin(h_labels), linalg::end(h_labels), 0); + + auto h_predt = + linalg::MakeTensorView(&ctx, predt.HostSpan(), kRows, kTargets).Slice(linalg::All(), t); + for (size_t i = 0; i < h_predt.Size(); ++i) { + h_predt(i) = h_labels(i) + i; + } + + HostDeviceVector position(h_labels.Size(), 0); + auto& h_position = position.HostVector(); + for (int32_t i = 0; i < 3; ++i) { + h_position[i] = ~i; // negation for sampled nodes. + } + for (size_t i = 3; i < 8; ++i) { + h_position[i] = 3; + } + // empty leaf for node 4 + for (size_t i = 8; i < 13; ++i) { + h_position[i] = 5; + } + for (size_t i = 13; i < h_labels.Size(); ++i) { + h_position[i] = 6; + } + + RegTree tree; + tree.ExpandNode(0, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f); + tree.ExpandNode(1, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f); + tree.ExpandNode(2, /*split_index=*/1, 2, true, 0.0f, 2.f, 3.f, 4.f, 2.f, 1.f, 1.f); + ASSERT_EQ(tree.GetNumLeaves(), 4); + + auto empty_leaf = tree[4].LeafValue(); + + tree::TrainParam param; + param.Init(Args{}); + auto lr = param.learning_rate; + + obj->UpdateTreeLeaf(position, info, lr, predt, t, &tree); + ASSERT_EQ(tree[3].LeafValue(), -5.0f * lr); + ASSERT_EQ(tree[4].LeafValue(), empty_leaf * lr); + ASSERT_EQ(tree[5].LeafValue(), -10.0f * lr); + ASSERT_EQ(tree[6].LeafValue(), -14.0f * lr); + } +} + +TEST(Adaptive, DeclareUnifiedTest(MissingLeaf)) { + std::vector missing{1, 3}; + + std::vector h_nidx = {2, 4, 5}; + std::vector h_nptr = {0, 4, 8, 16}; + + obj::detail::FillMissingLeaf(missing, &h_nidx, &h_nptr); + + ASSERT_EQ(h_nidx[0], missing[0]); + ASSERT_EQ(h_nidx[2], missing[1]); + ASSERT_EQ(h_nidx[1], 2); + ASSERT_EQ(h_nidx[3], 4); + ASSERT_EQ(h_nidx[4], 5); + + ASSERT_EQ(h_nptr[0], 0); + ASSERT_EQ(h_nptr[1], 0); // empty + ASSERT_EQ(h_nptr[2], 4); + ASSERT_EQ(h_nptr[3], 4); // empty + ASSERT_EQ(h_nptr[4], 8); + ASSERT_EQ(h_nptr[5], 16); +} +} // namespace xgboost diff --git a/tests/cpp/objective/test_regression_obj_gpu.cu b/tests/cpp/objective/test_regression_obj_gpu.cu index 38f29b8a8..746468f71 100644 --- a/tests/cpp/objective/test_regression_obj_gpu.cu +++ b/tests/cpp/objective/test_regression_obj_gpu.cu @@ -3,4 +3,4 @@ */ // Dummy file to keep the CUDA tests. -#include "test_regression_obj.cc" +#include "test_regression_obj_cpu.cc" diff --git a/tests/cpp/plugin/test_sycl_multiclass_obj.cc b/tests/cpp/plugin/test_sycl_multiclass_obj.cc new file mode 100644 index 000000000..d809ecad3 --- /dev/null +++ b/tests/cpp/plugin/test_sycl_multiclass_obj.cc @@ -0,0 +1,28 @@ +/*! + * Copyright 2018-2023 XGBoost contributors + */ +#include +#include + +#include "../objective/test_multiclass_obj.h" + +namespace xgboost { + +TEST(SyclObjective, SoftmaxMultiClassObjGPair) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + TestSoftmaxMultiClassObjGPair(&ctx); +} + +TEST(SyclObjective, SoftmaxMultiClassBasic) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + TestSoftmaxMultiClassObjGPair(&ctx); +} + +TEST(SyclObjective, SoftprobMultiClassBasic) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + TestSoftprobMultiClassBasic(&ctx); +} +} // namespace xgboost diff --git a/tests/cpp/plugin/test_sycl_regression_obj.cc b/tests/cpp/plugin/test_sycl_regression_obj.cc new file mode 100644 index 000000000..66b4ea508 --- /dev/null +++ b/tests/cpp/plugin/test_sycl_regression_obj.cc @@ -0,0 +1,99 @@ +/*! + * Copyright 2017-2019 XGBoost contributors + */ +#include +#include +#include + +#include "../helpers.h" +#include "../objective/test_regression_obj.h" + +namespace xgboost { + +TEST(SyclObjective, LinearRegressionGPair) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + TestLinearRegressionGPair(&ctx); +} + +TEST(SyclObjective, SquaredLog) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + TestSquaredLog(&ctx); +} + +TEST(SyclObjective, LogisticRegressionGPair) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + TestLogisticRegressionGPair(&ctx); +} + +TEST(SyclObjective, LogisticRegressionBasic) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + + TestLogisticRegressionBasic(&ctx); +} + +TEST(SyclObjective, LogisticRawGPair) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + TestsLogisticRawGPair(&ctx); +} + +TEST(SyclObjective, CPUvsSycl) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"device", "sycl"}}); + ObjFunction * obj_sycl = + ObjFunction::Create("reg:squarederror_sycl", &ctx); + + ctx = ctx.MakeCPU(); + ObjFunction * obj_cpu = + ObjFunction::Create("reg:squarederror", &ctx); + + linalg::Matrix cpu_out_preds; + linalg::Matrix sycl_out_preds; + + constexpr size_t kRows = 400; + constexpr size_t kCols = 100; + auto pdmat = RandomDataGenerator(kRows, kCols, 0).Seed(0).GenerateDMatrix(); + HostDeviceVector preds; + preds.Resize(kRows); + auto& h_preds = preds.HostVector(); + for (size_t i = 0; i < h_preds.size(); ++i) { + h_preds[i] = static_cast(i); + } + auto& info = pdmat->Info(); + + info.labels.Reshape(kRows, 1); + auto& h_labels = info.labels.Data()->HostVector(); + for (size_t i = 0; i < h_labels.size(); ++i) { + h_labels[i] = 1 / static_cast(i+1); + } + + { + // CPU + obj_cpu->GetGradient(preds, info, 0, &cpu_out_preds); + } + { + // sycl + obj_sycl->GetGradient(preds, info, 0, &sycl_out_preds); + } + + auto h_cpu_out = cpu_out_preds.HostView(); + auto h_sycl_out = sycl_out_preds.HostView(); + + float sgrad = 0; + float shess = 0; + for (size_t i = 0; i < kRows; ++i) { + sgrad += std::pow(h_cpu_out(i).GetGrad() - h_sycl_out(i).GetGrad(), 2); + shess += std::pow(h_cpu_out(i).GetHess() - h_sycl_out(i).GetHess(), 2); + } + ASSERT_NEAR(sgrad, 0.0f, kRtEps); + ASSERT_NEAR(shess, 0.0f, kRtEps); + + delete obj_cpu; + delete obj_sycl; +} + +} // namespace xgboost