Send default configuration from metric to objective. (#8760)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -10,16 +11,56 @@ from xgboost import testing as tm
|
||||
dpath = tm.data_dir(__file__)
|
||||
|
||||
|
||||
def test_aft_survival_toy_data():
|
||||
# See demo/aft_survival/aft_survival_viz_demo.py
|
||||
@pytest.fixture(scope="module")
|
||||
def toy_data() -> Tuple[xgb.DMatrix, np.ndarray, np.ndarray]:
|
||||
X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1))
|
||||
INF = np.inf
|
||||
y_lower = np.array([ 10, 15, -INF, 30, 100])
|
||||
y_upper = np.array([INF, INF, 20, 50, INF])
|
||||
y_lower = np.array([10, 15, -INF, 30, 100])
|
||||
y_upper = np.array([INF, INF, 20, 50, INF])
|
||||
|
||||
dmat = xgb.DMatrix(X)
|
||||
dmat.set_float_info('label_lower_bound', y_lower)
|
||||
dmat.set_float_info('label_upper_bound', y_upper)
|
||||
dmat.set_float_info("label_lower_bound", y_lower)
|
||||
dmat.set_float_info("label_upper_bound", y_upper)
|
||||
return dmat, y_lower, y_upper
|
||||
|
||||
|
||||
def test_default_metric(toy_data: Tuple[xgb.DMatrix, np.ndarray, np.ndarray]) -> None:
|
||||
Xy, y_lower, y_upper = toy_data
|
||||
|
||||
def run(evals: Optional[list]) -> None:
|
||||
# test with or without actual evaluation.
|
||||
booster = xgb.train(
|
||||
{"objective": "survival:aft", "aft_loss_distribution": "extreme"},
|
||||
Xy,
|
||||
num_boost_round=1,
|
||||
evals=evals,
|
||||
)
|
||||
config = json.loads(booster.save_config())
|
||||
metrics = config["learner"]["metrics"]
|
||||
assert len(metrics) == 1
|
||||
assert metrics[0]["aft_loss_param"]["aft_loss_distribution"] == "extreme"
|
||||
|
||||
booster = xgb.train(
|
||||
{"objective": "survival:aft"},
|
||||
Xy,
|
||||
num_boost_round=1,
|
||||
evals=evals,
|
||||
)
|
||||
config = json.loads(booster.save_config())
|
||||
metrics = config["learner"]["metrics"]
|
||||
assert len(metrics) == 1
|
||||
assert metrics[0]["aft_loss_param"]["aft_loss_distribution"] == "normal"
|
||||
|
||||
run([(Xy, "Train")])
|
||||
run(None)
|
||||
|
||||
|
||||
def test_aft_survival_toy_data(
|
||||
toy_data: Tuple[xgb.DMatrix, np.ndarray, np.ndarray]
|
||||
) -> None:
|
||||
# See demo/aft_survival/aft_survival_viz_demo.py
|
||||
X = np.array([1, 2, 3, 4, 5]).reshape((-1, 1))
|
||||
dmat, y_lower, y_upper = toy_data
|
||||
|
||||
# "Accuracy" = the number of data points whose ranged label (y_lower, y_upper) includes
|
||||
# the corresponding predicted label (y_pred)
|
||||
|
||||
Reference in New Issue
Block a user