Merge branch 'master' into sync-condition-2023May15

This commit is contained in:
amdsc21
2023-05-23 01:07:50 +02:00
8 changed files with 66 additions and 51 deletions

View File

@@ -603,26 +603,6 @@ sparse_datasets_strategy = strategies.sampled_from(
]
)
_unweighted_datasets_strategy = strategies.sampled_from(
[
TestDataset(
"calif_housing", get_california_housing, "reg:squarederror", "rmse"
),
TestDataset(
"calif_housing-l1", get_california_housing, "reg:absoluteerror", "mae"
),
TestDataset("cancer", get_cancer, "binary:logistic", "logloss"),
TestDataset("sparse", get_sparse, "reg:squarederror", "rmse"),
TestDataset("sparse-l1", get_sparse, "reg:absoluteerror", "mae"),
TestDataset(
"empty",
lambda: (np.empty((0, 100)), np.empty(0)),
"reg:squarederror",
"rmse",
),
]
)
def make_datasets_with_margin(
unweighted_strategy: strategies.SearchStrategy,
@@ -664,7 +644,28 @@ def make_datasets_with_margin(
# A strategy for drawing from a set of example datasets. May add random weights to the
# dataset
dataset_strategy = make_datasets_with_margin(_unweighted_datasets_strategy)()
@memory.cache
def make_dataset_strategy() -> Callable:
_unweighted_datasets_strategy = strategies.sampled_from(
[
TestDataset(
"calif_housing", get_california_housing, "reg:squarederror", "rmse"
),
TestDataset(
"calif_housing-l1", get_california_housing, "reg:absoluteerror", "mae"
),
TestDataset("cancer", get_cancer, "binary:logistic", "logloss"),
TestDataset("sparse", get_sparse, "reg:squarederror", "rmse"),
TestDataset("sparse-l1", get_sparse, "reg:absoluteerror", "mae"),
TestDataset(
"empty",
lambda: (np.empty((0, 100)), np.empty(0)),
"reg:squarederror",
"rmse",
),
]
)
return make_datasets_with_margin(_unweighted_datasets_strategy)()
_unweighted_multi_datasets_strategy = strategies.sampled_from(