[dask] Fix prediction on DaskDMatrix with multiple meta data. (#6333)
* Unify the meta handling methods.
This commit is contained in:
@@ -566,6 +566,28 @@ def test_predict():
|
||||
assert shap.shape[1] == kCols + 1
|
||||
|
||||
|
||||
def test_predict_with_meta(client):
    """Check dask prediction on a DaskDMatrix that carries extra meta info.

    The training matrix is built with both sample weights and a base margin.
    The distributed prediction must be one-dimensional, have one entry per
    row, reflect the large base margin, and match element-for-element the
    prediction from a local ``DMatrix`` holding the same data and meta info.
    """
    X, y, w = generate_array(with_weights=True)
    chunk_size = 20
    # Offset the margin by 1e4 so its effect on the output is unmistakable.
    margin = da.random.random(kRows, chunk_size) + 1e4

    dtrain = DaskDMatrix(client, X, y, weight=w, base_margin=margin)
    output = xgb.dask.train(client, {}, dtrain, num_boost_round=4)
    booster = output['booster']

    lazy_prediction = xgb.dask.predict(client, model=booster, data=dtrain)
    assert lazy_prediction.ndim == 1
    assert lazy_prediction.shape[0] == kRows

    # Materialise the distributed result before comparing values.
    prediction = client.compute(lazy_prediction).result()
    assert np.all(prediction > 1e3)

    # Rebuild the same dataset as a local DMatrix with identical meta info.
    local = xgb.DMatrix(X.compute())
    local.set_info(
        label=y.compute(),
        weight=w.compute(),
        base_margin=margin.compute(),
    )
    single = booster.predict(local)  # Make sure the ordering is correct.
    assert np.all(prediction == single)
|
||||
|
||||
|
||||
def run_aft_survival(client, dmatrix_t):
|
||||
# survival doesn't handle empty dataset well.
|
||||
df = dd.read_csv(os.path.join(tm.PROJECT_ROOT, 'demo', 'data',
|
||||
|
||||
Reference in New Issue
Block a user