Sycl implementation for objective functions (#9846)
--------- Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
committed by
GitHub
parent
ddab49a8be
commit
43897b8296
@@ -18,7 +18,11 @@ DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg);
|
||||
namespace xgboost {
|
||||
// implement factory functions
|
||||
ObjFunction* ObjFunction::Create(const std::string& name, Context const* ctx) {
|
||||
auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(name);
|
||||
std::string obj_name = name;
|
||||
if (ctx->IsSycl()) {
|
||||
obj_name = GetSyclImplementationName(obj_name);
|
||||
}
|
||||
auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(obj_name);
|
||||
if (e == nullptr) {
|
||||
std::stringstream ss;
|
||||
for (const auto& entry : ::dmlc::Registry< ::xgboost::ObjFunctionReg>::List()) {
|
||||
@@ -32,6 +36,22 @@ ObjFunction* ObjFunction::Create(const std::string& name, Context const* ctx) {
|
||||
return pobj;
|
||||
}
|
||||
|
||||
/* If the objective function has sycl-specific implementation,
|
||||
* returns the specific implementation name.
|
||||
* Otherwise return the orginal name without modifications.
|
||||
*/
|
||||
std::string ObjFunction::GetSyclImplementationName(const std::string& name) {
|
||||
const std::string sycl_postfix = "_sycl";
|
||||
auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(name + sycl_postfix);
|
||||
if (e != nullptr) {
|
||||
// Function has specific sycl implementation
|
||||
return name + sycl_postfix;
|
||||
} else {
|
||||
// Function hasn't specific sycl implementation
|
||||
return name;
|
||||
}
|
||||
}
|
||||
|
||||
void ObjFunction::InitEstimation(MetaInfo const&, linalg::Tensor<float, 1>* base_score) const {
|
||||
CHECK(base_score);
|
||||
base_score->Reshape(1);
|
||||
|
||||
Reference in New Issue
Block a user