Calculate base_score based on input labels for mae. (#8107)

Fit an intercept as base score for abs loss.
This commit is contained in:
Jiaming Yuan
2022-09-20 20:53:54 +08:00
committed by GitHub
parent 4f42aa5f12
commit fffb1fca52
42 changed files with 999 additions and 343 deletions

View File

@@ -1537,13 +1537,56 @@ class TestWithDask:
@pytest.mark.skipif(**tm.no_dask())
@pytest.mark.gtest
def test_quantile_same_on_all_workers(self) -> None:
self.run_quantile('SameOnAllWorkers')
self.run_quantile("SameOnAllWorkers")
def test_adaptive(self) -> None:
def get_score(config: Dict) -> float:
return float(config["learner"]["learner_model_param"]["base_score"])
def local_test(rabit_args: List[bytes], worker_id: int) -> bool:
with xgb.dask.RabitContext(rabit_args):
if worker_id == 0:
y = np.array([0.0, 0.0, 0.0])
x = np.array([[0.0]] * 3)
else:
y = np.array([1000.0])
x = np.array(
[
[0.0],
]
)
Xy = xgb.DMatrix(x, y)
booster = xgb.train(
{"tree_method": "hist", "objective": "reg:absoluteerror"},
Xy,
num_boost_round=1,
)
config = json.loads(booster.save_config())
base_score = get_score(config)
assert base_score == 250.0
return True
with LocalCluster(n_workers=2, dashboard_address=":0") as cluster:
with Client(cluster) as client:
workers = _get_client_workers(client)
rabit_args = client.sync(
xgb.dask._get_rabit_args, len(workers), None, client
)
futures = []
for i, _ in enumerate(workers):
f = client.submit(local_test, rabit_args, i)
futures.append(f)
results = client.gather(futures)
assert all(results)
def test_n_workers(self) -> None:
with LocalCluster(n_workers=2, dashboard_address=":0") as cluster:
with Client(cluster) as client:
workers = _get_client_workers(client)
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
dX = client.submit(da.from_array, X, workers=[workers[0]]).result()
dy = client.submit(da.from_array, y, workers=[workers[0]]).result()