Send default configuration from metric to objective. (#8760)
This commit is contained in:
parent
5f76edd296
commit
199c421d60
@ -55,6 +55,11 @@ class ObjFunction : public Configurable {
|
|||||||
|
|
||||||
/*! \return the default evaluation metric for the objective */
|
/*! \return the default evaluation metric for the objective */
|
||||||
virtual const char* DefaultEvalMetric() const = 0;
|
virtual const char* DefaultEvalMetric() const = 0;
|
||||||
|
/**
|
||||||
|
* \brief Return the configuration for the default metric.
|
||||||
|
*/
|
||||||
|
virtual Json DefaultMetricConfig() const { return Json{Null{}}; }
|
||||||
|
|
||||||
// the following functions are optional, most of time default implementation is good enough
|
// the following functions are optional, most of time default implementation is good enough
|
||||||
/*!
|
/*!
|
||||||
* \brief transform prediction values, this is only called when Prediction is called
|
* \brief transform prediction values, this is only called when Prediction is called
|
||||||
|
|||||||
@ -520,6 +520,7 @@ class LearnerConfiguration : public Learner {
|
|||||||
|
|
||||||
auto const& objective_fn = learner_parameters.at("objective");
|
auto const& objective_fn = learner_parameters.at("objective");
|
||||||
if (!obj_) {
|
if (!obj_) {
|
||||||
|
CHECK_EQ(get<String const>(objective_fn["name"]), tparam_.objective);
|
||||||
obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_));
|
obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_));
|
||||||
}
|
}
|
||||||
obj_->LoadConfig(objective_fn);
|
obj_->LoadConfig(objective_fn);
|
||||||
@ -1311,8 +1312,10 @@ class LearnerImpl : public LearnerIO {
|
|||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
os.precision(std::numeric_limits<double>::max_digits10);
|
os.precision(std::numeric_limits<double>::max_digits10);
|
||||||
os << '[' << iter << ']' << std::setiosflags(std::ios::fixed);
|
os << '[' << iter << ']' << std::setiosflags(std::ios::fixed);
|
||||||
if (metrics_.size() == 0 && tparam_.disable_default_eval_metric <= 0) {
|
if (metrics_.empty() && tparam_.disable_default_eval_metric <= 0) {
|
||||||
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric(), &ctx_));
|
metrics_.emplace_back(Metric::Create(obj_->DefaultEvalMetric(), &ctx_));
|
||||||
|
auto config = obj_->DefaultMetricConfig();
|
||||||
|
metrics_.back()->LoadConfig(config);
|
||||||
metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
|
metrics_.back()->Configure({cfg_.begin(), cfg_.end()});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -134,6 +134,12 @@ class AFTObj : public ObjFunction {
|
|||||||
void LoadConfig(Json const& in) override {
|
void LoadConfig(Json const& in) override {
|
||||||
FromJson(in["aft_loss_param"], ¶m_);
|
FromJson(in["aft_loss_param"], ¶m_);
|
||||||
}
|
}
|
||||||
|
Json DefaultMetricConfig() const override {
|
||||||
|
Json config{Object{}};
|
||||||
|
config["name"] = String{this->DefaultEvalMetric()};
|
||||||
|
config["aft_loss_param"] = ToJson(param_);
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
AFTParam param_;
|
AFTParam param_;
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@ -10,16 +11,56 @@ from xgboost import testing as tm
|
|||||||
dpath = tm.data_dir(__file__)
|
dpath = tm.data_dir(__file__)
|
||||||
|
|
||||||
|
|
||||||
def test_aft_survival_toy_data():
|
@pytest.fixture(scope="module")
|
||||||
# See demo/aft_survival/aft_survival_viz_demo.py
|
def toy_data() -> Tuple[xgb.DMatrix, np.ndarray, np.ndarray]:
|
||||||
X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1))
|
X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1))
|
||||||
INF = np.inf
|
INF = np.inf
|
||||||
y_lower = np.array([ 10, 15, -INF, 30, 100])
|
y_lower = np.array([10, 15, -INF, 30, 100])
|
||||||
y_upper = np.array([INF, INF, 20, 50, INF])
|
y_upper = np.array([INF, INF, 20, 50, INF])
|
||||||
|
|
||||||
dmat = xgb.DMatrix(X)
|
dmat = xgb.DMatrix(X)
|
||||||
dmat.set_float_info('label_lower_bound', y_lower)
|
dmat.set_float_info("label_lower_bound", y_lower)
|
||||||
dmat.set_float_info('label_upper_bound', y_upper)
|
dmat.set_float_info("label_upper_bound", y_upper)
|
||||||
|
return dmat, y_lower, y_upper
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_metric(toy_data: Tuple[xgb.DMatrix, np.ndarray, np.ndarray]) -> None:
|
||||||
|
Xy, y_lower, y_upper = toy_data
|
||||||
|
|
||||||
|
def run(evals: Optional[list]) -> None:
|
||||||
|
# test with or without actual evaluation.
|
||||||
|
booster = xgb.train(
|
||||||
|
{"objective": "survival:aft", "aft_loss_distribution": "extreme"},
|
||||||
|
Xy,
|
||||||
|
num_boost_round=1,
|
||||||
|
evals=evals,
|
||||||
|
)
|
||||||
|
config = json.loads(booster.save_config())
|
||||||
|
metrics = config["learner"]["metrics"]
|
||||||
|
assert len(metrics) == 1
|
||||||
|
assert metrics[0]["aft_loss_param"]["aft_loss_distribution"] == "extreme"
|
||||||
|
|
||||||
|
booster = xgb.train(
|
||||||
|
{"objective": "survival:aft"},
|
||||||
|
Xy,
|
||||||
|
num_boost_round=1,
|
||||||
|
evals=evals,
|
||||||
|
)
|
||||||
|
config = json.loads(booster.save_config())
|
||||||
|
metrics = config["learner"]["metrics"]
|
||||||
|
assert len(metrics) == 1
|
||||||
|
assert metrics[0]["aft_loss_param"]["aft_loss_distribution"] == "normal"
|
||||||
|
|
||||||
|
run([(Xy, "Train")])
|
||||||
|
run(None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_aft_survival_toy_data(
|
||||||
|
toy_data: Tuple[xgb.DMatrix, np.ndarray, np.ndarray]
|
||||||
|
) -> None:
|
||||||
|
# See demo/aft_survival/aft_survival_viz_demo.py
|
||||||
|
X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1))
|
||||||
|
dmat, y_lower, y_upper = toy_data
|
||||||
|
|
||||||
# "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes
|
# "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes
|
||||||
# the corresponding predicted label (y_pred)
|
# the corresponding predicted label (y_pred)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user