[dask] Allow empty data matrix in AFT survival (#6379)
* [dask] Allow empty data matrix in AFT survival * Add unit test
This commit is contained in:
parent
5a33c2f3a0
commit
e5193c21a1
@ -206,10 +206,6 @@ struct EvalEWiseSurvivalBase : public Metric {
|
|||||||
bst_float Eval(const HostDeviceVector<bst_float>& preds,
|
bst_float Eval(const HostDeviceVector<bst_float>& preds,
|
||||||
const MetaInfo& info,
|
const MetaInfo& info,
|
||||||
bool distributed) override {
|
bool distributed) override {
|
||||||
CHECK_NE(info.labels_lower_bound_.Size(), 0U)
|
|
||||||
<< "labels_lower_bound cannot be empty";
|
|
||||||
CHECK_NE(info.labels_upper_bound_.Size(), 0U)
|
|
||||||
<< "labels_upper_bound cannot be empty";
|
|
||||||
CHECK_EQ(preds.Size(), info.labels_lower_bound_.Size());
|
CHECK_EQ(preds.Size(), info.labels_lower_bound_.Size());
|
||||||
CHECK_EQ(preds.Size(), info.labels_upper_bound_.Size());
|
CHECK_EQ(preds.Size(), info.labels_upper_bound_.Size());
|
||||||
|
|
||||||
|
|||||||
@ -52,6 +52,16 @@ def test_aft_survival_toy_data():
|
|||||||
for tree in model_json:
|
for tree in model_json:
|
||||||
assert gather_split_thresholds(tree).issubset({2.5, 3.5, 4.5})
|
assert gather_split_thresholds(tree).issubset({2.5, 3.5, 4.5})
|
||||||
|
|
||||||
|
|
||||||
|
def test_aft_empty_dmatrix():
|
||||||
|
X = np.array([]).reshape((0, 2))
|
||||||
|
y_lower, y_upper = np.array([]), np.array([])
|
||||||
|
dtrain = xgb.DMatrix(X)
|
||||||
|
dtrain.set_info(label_lower_bound=y_lower, label_upper_bound=y_upper)
|
||||||
|
bst = xgb.train({'objective': 'survival:aft', 'tree_method': 'hist'},
|
||||||
|
dtrain, num_boost_round=2, evals=[(dtrain, 'train')])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_pandas())
|
@pytest.mark.skipif(**tm.no_pandas())
|
||||||
def test_aft_survival_demo_data():
|
def test_aft_survival_demo_data():
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|||||||
@ -594,7 +594,6 @@ def test_predict_with_meta(client):
|
|||||||
|
|
||||||
|
|
||||||
def run_aft_survival(client, dmatrix_t):
|
def run_aft_survival(client, dmatrix_t):
|
||||||
# survival doesn't handle empty dataset well.
|
|
||||||
df = dd.read_csv(os.path.join(tm.PROJECT_ROOT, 'demo', 'data',
|
df = dd.read_csv(os.path.join(tm.PROJECT_ROOT, 'demo', 'data',
|
||||||
'veterans_lung_cancer.csv'))
|
'veterans_lung_cancer.csv'))
|
||||||
y_lower_bound = df['Survival_label_lower_bound']
|
y_lower_bound = df['Survival_label_lower_bound']
|
||||||
@ -632,7 +631,7 @@ def run_aft_survival(client, dmatrix_t):
|
|||||||
|
|
||||||
|
|
||||||
def test_aft_survival():
|
def test_aft_survival():
|
||||||
with LocalCluster(n_workers=1) as cluster:
|
with LocalCluster(n_workers=kWorkers) as cluster:
|
||||||
with Client(cluster) as client:
|
with Client(cluster) as client:
|
||||||
run_aft_survival(client, DaskDMatrix)
|
run_aft_survival(client, DaskDMatrix)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user