[dask] Fix prediction on DaskDMatrix with multiple meta data. (#6333)

* Unify the meta handling methods.
This commit is contained in:
Jiaming Yuan
2020-11-02 19:18:44 -05:00
committed by GitHub
parent 5a7b3592ed
commit 7756192906
2 changed files with 85 additions and 58 deletions

View File

@@ -566,6 +566,28 @@ def test_predict():
assert shap.shape[1] == kCols + 1
def test_predict_with_meta(client):
X, y, w = generate_array(with_weights=True)
partition_size = 20
margin = da.random.random(kRows, partition_size) + 1e4
dtrain = DaskDMatrix(client, X, y, weight=w, base_margin=margin)
booster = xgb.dask.train(
client, {}, dtrain, num_boost_round=4)['booster']
prediction = xgb.dask.predict(client, model=booster, data=dtrain)
assert prediction.ndim == 1
assert prediction.shape[0] == kRows
prediction = client.compute(prediction).result()
assert np.all(prediction > 1e3)
m = xgb.DMatrix(X.compute())
m.set_info(label=y.compute(), weight=w.compute(), base_margin=margin.compute())
single = booster.predict(m) # Make sure the ordering is correct.
assert np.all(prediction == single)
def run_aft_survival(client, dmatrix_t):
# survival doesn't handle empty dataset well.
df = dd.read_csv(os.path.join(tm.PROJECT_ROOT, 'demo', 'data',