Sycl implementation for objective functions (#9846)

---------

Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
Dmitry Razdoburdin
2023-12-12 07:41:50 +01:00
committed by GitHub
parent ddab49a8be
commit 43897b8296
19 changed files with 1129 additions and 423 deletions

View File

@@ -1,18 +1,18 @@
/*!
* Copyright 2018-2019 XGBoost contributors
* Copyright 2018-2023 XGBoost contributors
*/
#include <xgboost/objective.h>
#include <xgboost/context.h>
#include "../../src/common/common.h"
#include "../helpers.h"
#include "test_multiclass_obj.h"
namespace xgboost {
TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassObjGPair)) {
Context ctx = MakeCUDACtx(GPUIDX);
void TestSoftmaxMultiClassObjGPair(const Context* ctx) {
std::vector<std::pair<std::string, std::string>> args {{"num_class", "3"}};
std::unique_ptr<ObjFunction> obj {
ObjFunction::Create("multi:softmax", &ctx)
ObjFunction::Create("multi:softmax", ctx)
};
obj->Configure(args);
@@ -35,12 +35,11 @@ TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassObjGPair)) {
ASSERT_NO_THROW(obj->DefaultEvalMetric());
}
TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassBasic)) {
auto ctx = MakeCUDACtx(GPUIDX);
void TestSoftmaxMultiClassBasic(const Context* ctx) {
std::vector<std::pair<std::string, std::string>> args{
std::pair<std::string, std::string>("num_class", "3")};
std::unique_ptr<ObjFunction> obj{ObjFunction::Create("multi:softmax", &ctx)};
std::unique_ptr<ObjFunction> obj{ObjFunction::Create("multi:softmax", ctx)};
obj->Configure(args);
CheckConfigReload(obj, "multi:softmax");
@@ -56,13 +55,12 @@ TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassBasic)) {
}
}
TEST(Objective, DeclareUnifiedTest(SoftprobMultiClassBasic)) {
Context ctx = MakeCUDACtx(GPUIDX);
void TestSoftprobMultiClassBasic(const Context* ctx) {
std::vector<std::pair<std::string, std::string>> args {
std::pair<std::string, std::string>("num_class", "3")};
std::unique_ptr<ObjFunction> obj {
ObjFunction::Create("multi:softprob", &ctx)
ObjFunction::Create("multi:softprob", ctx)
};
obj->Configure(args);
CheckConfigReload(obj, "multi:softprob");
@@ -77,4 +75,5 @@ TEST(Objective, DeclareUnifiedTest(SoftprobMultiClassBasic)) {
EXPECT_NEAR(preds[i], out_preds[i], 0.01f);
}
}
} // namespace xgboost