merge latest, Jan 12 2024

This commit is contained in:
Hui Liu
2024-01-12 09:57:11 -08:00
251 changed files with 9023 additions and 5012 deletions

View File

@@ -18,7 +18,11 @@ DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg);
namespace xgboost {
// implement factory functions
ObjFunction* ObjFunction::Create(const std::string& name, Context const* ctx) {
auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(name);
std::string obj_name = name;
if (ctx->IsSycl()) {
obj_name = GetSyclImplementationName(obj_name);
}
auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(obj_name);
if (e == nullptr) {
std::stringstream ss;
for (const auto& entry : ::dmlc::Registry< ::xgboost::ObjFunctionReg>::List()) {
@@ -32,6 +36,22 @@ ObjFunction* ObjFunction::Create(const std::string& name, Context const* ctx) {
return pobj;
}
/* If the objective function has sycl-specific implementation,
* returns the specific implementation name.
* Otherwise return the orginal name without modifications.
*/
std::string ObjFunction::GetSyclImplementationName(const std::string& name) {
const std::string sycl_postfix = "_sycl";
auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(name + sycl_postfix);
if (e != nullptr) {
// Function has specific sycl implementation
return name + sycl_postfix;
} else {
// Function hasn't specific sycl implementation
return name;
}
}
void ObjFunction::InitEstimation(MetaInfo const&, linalg::Tensor<float, 1>* base_score) const {
CHECK(base_score);
base_score->Reshape(1);