Initial support for quantile loss. (#8750)
- Add support for Python. - Add objective.
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
// Copyright by Contributors
|
||||
/**
|
||||
* Copyright 2016-2023 by XGBoost contributors
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
#include <xgboost/context.h>
|
||||
#include <xgboost/objective.h>
|
||||
@@ -25,11 +27,14 @@ TEST(Objective, PredTransform) {
|
||||
tparam.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
|
||||
size_t n = 100;
|
||||
|
||||
for (const auto &entry :
|
||||
::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) {
|
||||
std::unique_ptr<xgboost::ObjFunction> obj{
|
||||
xgboost::ObjFunction::Create(entry->name, &tparam)};
|
||||
obj->Configure(Args{{"num_class", "2"}});
|
||||
for (const auto& entry : ::dmlc::Registry<::xgboost::ObjFunctionReg>::List()) {
|
||||
std::unique_ptr<xgboost::ObjFunction> obj{xgboost::ObjFunction::Create(entry->name, &tparam)};
|
||||
if (entry->name.find("multi") != std::string::npos) {
|
||||
obj->Configure(Args{{"num_class", "2"}});
|
||||
}
|
||||
if (entry->name.find("quantile") != std::string::npos) {
|
||||
obj->Configure(Args{{"quantile_alpha", "0.5"}});
|
||||
}
|
||||
HostDeviceVector<float> predts;
|
||||
predts.Resize(n, 3.14f); // prediction is performed on host.
|
||||
ASSERT_FALSE(predts.DeviceCanRead());
|
||||
|
||||
Reference in New Issue
Block a user