Initial support for multi-target tree. (#8616)

* Implement multi-target for hist.

- Add new hist tree builder.
- Move data fetchers for tests.
- Dispatch function calls in gbm base on the tree type.
This commit is contained in:
Jiaming Yuan
2023-03-22 23:49:56 +08:00
committed by GitHub
parent ea04d4c46c
commit 151882dd26
34 changed files with 856 additions and 389 deletions

View File

@@ -116,7 +116,7 @@ def test_with_mq2008(objective, metric) -> None:
x_valid,
y_valid,
qid_valid,
) = tm.get_mq2008(os.path.join(os.path.join(tm.demo_dir(__file__), "rank")))
) = tm.data.get_mq2008(os.path.join(os.path.join(tm.demo_dir(__file__), "rank")))
if metric.find("map") != -1 or objective.find("map") != -1:
y_train[y_train <= 1] = 0.0

View File

@@ -32,6 +32,19 @@ def train_result(param, dmat: xgb.DMatrix, num_rounds: int) -> dict:
return result
class TestGPUUpdatersMulti:
@given(
hist_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
)
@settings(deadline=None, max_examples=50, print_blob=True)
def test_hist(self, param, num_rounds, dataset):
param["tree_method"] = "gpu_hist"
param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), num_rounds)
note(result)
assert tm.non_increasing(result["train"][dataset.metric])
class TestGPUUpdaters:
cputest = test_up.TestTreeMethod()
@@ -101,7 +114,7 @@ class TestGPUUpdaters:
) -> None:
cat_parameters.update(hist_parameters)
dataset = tm.TestDataset(
"ames_housing", tm.get_ames_housing, "reg:squarederror", "rmse"
"ames_housing", tm.data.get_ames_housing, "reg:squarederror", "rmse"
)
cat_parameters["tree_method"] = "gpu_hist"
results = train_result(cat_parameters, dataset.get_dmat(), 16)