Multi-target support for L1 error. (#8652)

- Add matrix support to the median function.
- Iterate through each target for quantile computation.
This commit is contained in:
Jiaming Yuan
2023-01-11 05:51:14 +08:00
committed by GitHub
parent badeff1d74
commit cfa994d57f
19 changed files with 430 additions and 215 deletions

View File

@@ -317,13 +317,13 @@ class TestDataset:
enable_categorical=True,
)
def get_device_dmat(self) -> xgb.DeviceQuantileDMatrix:
def get_device_dmat(self) -> xgb.QuantileDMatrix:
import cupy as cp
w = None if self.w is None else cp.array(self.w)
X = cp.array(self.X, dtype=np.float32)
y = cp.array(self.y, dtype=np.float32)
return xgb.DeviceQuantileDMatrix(X, y, w, base_margin=self.margin)
return xgb.QuantileDMatrix(X, y, weight=w, base_margin=self.margin)
def get_external_dmat(self) -> xgb.DMatrix:
n_samples = self.X.shape[0]
@@ -726,10 +726,16 @@ _unweighted_datasets_strategy = strategies.sampled_from(
TestDataset("cancer", get_cancer, "binary:logistic", "logloss"),
TestDataset(
"mtreg",
lambda: datasets.make_regression(n_samples=128, n_targets=3),
lambda: datasets.make_regression(n_samples=128, n_features=2, n_targets=3),
"reg:squarederror",
"rmse",
),
TestDataset(
"mtreg-l1",
lambda: datasets.make_regression(n_samples=128, n_features=2, n_targets=3),
"reg:absoluteerror",
"mae",
),
TestDataset("sparse", get_sparse, "reg:squarederror", "rmse"),
TestDataset("sparse-l1", get_sparse, "reg:absoluteerror", "mae"),
TestDataset(
@@ -753,7 +759,7 @@ def _dataset_weight_margin(draw: Callable) -> TestDataset:
num_class = 1
if data.objective == "multi:softmax":
num_class = int(np.max(data.y) + 1)
elif data.name == "mtreg":
elif data.name.startswith("mtreg"):
num_class = data.y.shape[1]
data.margin = draw(