[dask] Support more metadata on the functional interface. (#6132)

* Add base_margin, label_(lower|upper)_bound.
* Test survival training with dask.
commit 33d80ffad0 (parent 7065779afa)
Author: Jiaming Yuan
Date: 2020-09-21 16:56:37 +08:00 (committed by GitHub)
4 changed files with 154 additions and 48 deletions
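
For orientation, the headline change is that xgb.dask.DaskDMatrix now accepts base_margin, label_lower_bound, and label_upper_bound. A minimal usage sketch under assumptions not in the diff (a local cluster and synthetic interval labels; the partitioning of the bounds must line up row-wise with X):

    import dask.array as da
    import xgboost as xgb
    from dask.distributed import Client, LocalCluster

    with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
        X = da.random.random((100, 10), chunks=(50, 10))
        # Interval labels for survival:aft; chunks match X row-wise.
        y_lower = da.random.uniform(1.0, 2.0, size=(100,), chunks=(50,))
        y_upper = y_lower + 1.0
        m = xgb.dask.DaskDMatrix(client, X,
                                 label_lower_bound=y_lower,
                                 label_upper_bound=y_upper)
        out = xgb.dask.train(client, {'objective': 'survival:aft'}, m,
                             num_boost_round=4)
        booster = out['booster']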

diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py

@@ -320,7 +320,8 @@ class DataIter:
         def data_handle(data, label=None, weight=None, base_margin=None,
                         group=None,
                         label_lower_bound=None, label_upper_bound=None,
-                        feature_names=None, feature_types=None):
+                        feature_names=None, feature_types=None,
+                        feature_weights=None):
             from .data import dispatch_device_quantile_dmatrix_set_data
             from .data import _device_quantile_transform
             data, feature_names, feature_types = _device_quantile_transform(
@@ -333,7 +334,8 @@ class DataIter:
                 label_lower_bound=label_lower_bound,
                 label_upper_bound=label_upper_bound,
                 feature_names=feature_names,
-                feature_types=feature_types)
+                feature_types=feature_types,
+                feature_weights=feature_weights)
             try:
                 # Defer the exception in order to return 0 and stop the iteration.
                 # Exception inside a ctypes callback function has no effect except

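The core.py hunk only threads a new feature_weights field through the device-quantile iterator callback; nothing in this file consumes it directly. For context, a rough sketch of what that field does on a plain DMatrix (assuming a build where DMatrix.set_info accepts feature_weights, which this diff does not show): feature weights bias column sampling when a colsample_by* parameter is active.

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(64, 4)
    y = np.random.randn(64)
    fw = np.array([1.0, 1.0, 2.0, 4.0])  # later columns sampled more often

    dtrain = xgb.DMatrix(X, y)
    dtrain.set_info(feature_weights=fw)  # assumption: set_info supports it
    booster = xgb.train({'colsample_bynode': 0.5}, dtrain, num_boost_round=2)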
diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py

@@ -178,6 +178,12 @@ class DaskDMatrix:
         to be present as a missing value. If None, defaults to np.nan.
     weight : dask.array.Array/dask.dataframe.DataFrame
         Weight for each instance.
+    base_margin : dask.array.Array/dask.dataframe.DataFrame
+        Global bias for each instance.
+    label_lower_bound : dask.array.Array/dask.dataframe.DataFrame
+        Lower bound for survival training.
+    label_upper_bound : dask.array.Array/dask.dataframe.DataFrame
+        Upper bound for survival training.
     feature_names : list, optional
         Set names for features.
     feature_types : list, optional
@@ -191,6 +197,9 @@ class DaskDMatrix:
                  label=None,
                  missing=None,
                  weight=None,
+                 base_margin=None,
+                 label_lower_bound=None,
+                 label_upper_bound=None,
                  feature_names=None,
                  feature_types=None):
         _assert_dask_support()
@@ -216,12 +225,17 @@ class DaskDMatrix:
         self.is_quantile = False

-        self._init = client.sync(self.map_local_data,
-                                 client, data, label, weight)
+        self._init = client.sync(self.map_local_data,
+                                 client, data, label=label, weights=weight,
+                                 base_margin=base_margin,
+                                 label_lower_bound=label_lower_bound,
+                                 label_upper_bound=label_upper_bound)

     def __await__(self):
         return self._init.__await__()

-    async def map_local_data(self, client, data, label=None, weights=None):
+    async def map_local_data(self, client, data, label=None, weights=None,
+                             base_margin=None,
+                             label_lower_bound=None, label_upper_bound=None):
         '''Obtain references to local data.'''

         def inconsistent(left, left_name, right, right_name):
@@ -241,10 +255,10 @@ class DaskDMatrix:
                 ' chunks=(partition_size, X.shape[1])'
         data = data.persist()
-        if label is not None:
-            label = label.persist()
-        if weights is not None:
-            weights = weights.persist()
+        for meta in [label, weights, base_margin, label_lower_bound,
+                     label_upper_bound]:
+            if meta is not None:
+                meta = meta.persist()

         # Breaking data into partitions, a trick borrowed from dask_xgboost.
         # `to_delayed` downgrades high-level objects into numpy or pandas
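A note on the persist loop above: dask's persist is not in-place. It submits the graph for execution on the cluster and returns a new collection backed by the results, so the return value has to be captured. A tiny standalone sketch (in-process client, purely illustrative):

    import dask.array as da
    from dask.distributed import Client

    client = Client(processes=False)
    arr = da.ones((10,), chunks=(5,))
    arr = arr.persist()            # rebind: persist returns a new collection
    print(arr.sum().compute())     # 10.0
    client.close()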
@@ -254,29 +268,37 @@ class DaskDMatrix:
         check_columns(X_parts)
         X_parts = X_parts.flatten().tolist()

-        if label is not None:
-            y_parts = label.to_delayed()
-            if isinstance(y_parts, numpy.ndarray):
-                check_columns(y_parts)
-                y_parts = y_parts.flatten().tolist()
-        if weights is not None:
-            w_parts = weights.to_delayed()
-            if isinstance(w_parts, numpy.ndarray):
-                check_columns(w_parts)
-                w_parts = w_parts.flatten().tolist()
+        def flatten_meta(meta):
+            if meta is not None:
+                meta_parts = meta.to_delayed()
+                if isinstance(meta_parts, numpy.ndarray):
+                    check_columns(meta_parts)
+                    meta_parts = meta_parts.flatten().tolist()
+                return meta_parts
+            return None
+
+        y_parts = flatten_meta(label)
+        w_parts = flatten_meta(weights)
+        margin_parts = flatten_meta(base_margin)
+        ll_parts = flatten_meta(label_lower_bound)
+        lu_parts = flatten_meta(label_upper_bound)

         parts = [X_parts]
         meta_names = []
-        if label is not None:
-            assert len(X_parts) == len(
-                y_parts), inconsistent(X_parts, 'X', y_parts, 'labels')
-            parts.append(y_parts)
-            meta_names.append('labels')
-        if weights is not None:
-            assert len(X_parts) == len(
-                w_parts), inconsistent(X_parts, 'X', w_parts, 'weights')
-            parts.append(w_parts)
-            meta_names.append('weights')
+
+        def append_meta(m_parts, name: str):
+            if m_parts is not None:
+                assert len(X_parts) == len(
+                    m_parts), inconsistent(X_parts, 'X', m_parts, name)
+                parts.append(m_parts)
+                meta_names.append(name)
+
+        append_meta(y_parts, 'labels')
+        append_meta(w_parts, 'weights')
+        append_meta(margin_parts, 'base_margin')
+        append_meta(ll_parts, 'label_lower_bound')
+        append_meta(lu_parts, 'label_upper_bound')

         parts = list(map(delayed, zip(*parts)))
         parts = client.compute(parts)
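The flatten_meta/append_meta pair leans on dask's to_delayed, which exposes one Delayed object per partition; for a dask array it comes back as a numpy object array shaped like the chunk grid, hence the flatten().tolist(). A standalone sketch of that mechanism (local client, synthetic array):

    import dask.array as da
    from dask.distributed import Client

    client = Client(processes=False)
    X = da.random.random((8, 2), chunks=(4, 2))
    parts = X.to_delayed()             # 2x1 object array of Delayed chunks
    parts = parts.flatten().tolist()
    futures = client.compute(parts)    # one future per partition
    print([f.result().shape for f in futures])
    client.close()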
@@ -339,6 +361,9 @@ def _get_worker_parts(worker_map, meta_names, worker):
     data = None
     labels = None
     weights = None
+    base_margin = None
+    label_lower_bound = None
+    label_upper_bound = None

     local_data = list(zip(*list_of_parts))
     data = local_data[0]
@@ -348,8 +373,15 @@ def _get_worker_parts(worker_map, meta_names, worker):
             labels = part
         if meta_names[i] == 'weights':
             weights = part
+        if meta_names[i] == 'base_margin':
+            base_margin = part
+        if meta_names[i] == 'label_lower_bound':
+            label_lower_bound = part
+        if meta_names[i] == 'label_upper_bound':
+            label_upper_bound = part

-    return data, labels, weights
+    return (data, labels, weights, base_margin, label_lower_bound,
+            label_upper_bound)


 class DaskPartitionIter(DataIter):  # pylint: disable=R0902
@@ -456,13 +488,22 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
     '''
-    def __init__(self, client, data, label=None, weight=None,
+    def __init__(self, client,
+                 data,
+                 label=None,
                  missing=None,
+                 weight=None,
+                 base_margin=None,
+                 label_lower_bound=None,
+                 label_upper_bound=None,
                  feature_names=None,
                  feature_types=None,
                  max_bin=256):
-        super().__init__(client=client, data=data, label=label, weight=weight,
+        super().__init__(client=client, data=data, label=label,
                          missing=missing,
+                         weight=weight, base_margin=base_margin,
+                         label_lower_bound=label_lower_bound,
+                         label_upper_bound=label_upper_bound,
                          feature_names=feature_names,
                          feature_types=feature_types)
         self.max_bin = max_bin
@@ -491,8 +532,13 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
                                 max_bin=max_bin)
         return d

-    data, labels, weights = _get_worker_parts(worker_map, meta_names, worker)
-    it = DaskPartitionIter(data=data, label=labels, weight=weights)
+    (data, labels, weights, base_margin,
+     label_lower_bound, label_upper_bound) = _get_worker_parts(
+         worker_map, meta_names, worker)
+    it = DaskPartitionIter(data=data, label=labels, weight=weights,
+                           base_margin=base_margin,
+                           label_lower_bound=label_lower_bound,
+                           label_upper_bound=label_upper_bound)

     dmatrix = DeviceQuantileDMatrix(it,
                                     missing=missing,
@@ -524,20 +570,31 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
                     feature_types=feature_types)
         return d

-    data, labels, weights = _get_worker_parts(worker_map, meta_names, worker)
-    data = concat(data)
+    def concat_or_none(data):
+        if data is not None:
+            return concat(data)
+        return data

-    if labels:
-        labels = concat(labels)
-    if weights:
-        weights = concat(weights)
+    (data, labels, weights, base_margin,
+     label_lower_bound, label_upper_bound) = _get_worker_parts(
+         worker_map, meta_names, worker)
+
+    labels = concat_or_none(labels)
+    weights = concat_or_none(weights)
+    base_margin = concat_or_none(base_margin)
+    label_lower_bound = concat_or_none(label_lower_bound)
+    label_upper_bound = concat_or_none(label_upper_bound)
+
+    data = concat(data)

     dmatrix = DMatrix(data,
                       labels,
-                      weight=weights,
                       missing=missing,
                       feature_names=feature_names,
                       feature_types=feature_types,
                       nthread=worker.nthreads)
+    dmatrix.set_info(base_margin=base_margin, weight=weights,
+                     label_lower_bound=label_lower_bound,
+                     label_upper_bound=label_upper_bound)
     return dmatrix
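Note that the worker-side DMatrix now receives the extra fields through set_info after construction instead of through the constructor. The same call on a plain DMatrix, as a minimal sketch with synthetic data:

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(32, 3)
    dmatrix = xgb.DMatrix(X)
    dmatrix.set_info(base_margin=np.zeros(32),
                     weight=np.ones(32),
                     label_lower_bound=np.full(32, 1.0),
                     label_upper_bound=np.full(32, 2.0))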
@@ -683,7 +740,8 @@ async def _direct_predict_impl(client, data, predict_fn):

 # pylint: disable=too-many-statements
-async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwargs):
+async def _predict_async(client: Client, model, data, missing=numpy.nan,
+                         **kwargs):
     if isinstance(model, Booster):
         booster = model

diff --git a/tests/python/test_survival.py b/tests/python/test_survival.py

@@ -3,9 +3,10 @@ import pytest
 import numpy as np
 import xgboost as xgb
 import json
-from pathlib import Path
+import os

-dpath = Path('demo/data')
+dpath = os.path.join(tm.PROJECT_ROOT, 'demo', 'data')


 def test_aft_survival_toy_data():
     # See demo/aft_survival/aft_survival_viz_demo.py
@@ -51,10 +52,10 @@ def test_aft_survival_toy_data():
     for tree in model_json:
         assert gather_split_thresholds(tree).issubset({2.5, 3.5, 4.5})


 @pytest.mark.skipif(**tm.no_pandas())
 def test_aft_survival_demo_data():
     import pandas as pd
-    df = pd.read_csv(dpath / 'veterans_lung_cancer.csv')
+    df = pd.read_csv(os.path.join(dpath, 'veterans_lung_cancer.csv'))
     y_lower_bound = df['Survival_label_lower_bound']
     y_upper_bound = df['Survival_label_upper_bound']

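For readers new to survival:aft: labels are intervals rather than points. An uncensored observation has equal lower and upper bounds, while a right-censored one uses +inf as its upper bound. A small single-machine sketch with synthetic data (not the demo CSV used by the tests):

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(4, 2)
    y_lower = np.array([1.0, 2.0, 3.0, 4.0])
    y_upper = np.array([1.0, np.inf, 3.5, np.inf])  # rows 2, 4 right-censored

    dtrain = xgb.DMatrix(X)
    dtrain.set_info(label_lower_bound=y_lower, label_upper_bound=y_upper)
    params = {'objective': 'survival:aft',
              'eval_metric': 'aft-nloglik',
              'aft_loss_distribution': 'normal',
              'aft_loss_distribution_scale': 1.20}
    booster = xgb.train(params, dtrain, num_boost_round=10)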
diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py

@@ -1,6 +1,5 @@
 import testing as tm
 import pytest
-import unittest
 import xgboost as xgb
 import sys
 import numpy as np
@@ -482,16 +481,62 @@ def test_predict():
     assert pred.ndim == 1
     assert pred.shape[0] == kRows

-    margin = xgb.dask.predict(client, model=booster, data=dtrain, output_margin=True)
+    margin = xgb.dask.predict(client, model=booster, data=dtrain,
+                              output_margin=True)
     assert margin.ndim == 1
     assert margin.shape[0] == kRows

-    shap = xgb.dask.predict(client, model=booster, data=dtrain, pred_contribs=True)
+    shap = xgb.dask.predict(client, model=booster, data=dtrain,
+                            pred_contribs=True)
     assert shap.ndim == 2
     assert shap.shape[0] == kRows
     assert shap.shape[1] == kCols + 1
+
+
+def run_aft_survival(client, dmatrix_t):
+    # Survival training doesn't handle an empty dataset well.
+    df = dd.read_csv(os.path.join(tm.PROJECT_ROOT, 'demo', 'data',
+                                  'veterans_lung_cancer.csv'))
+    y_lower_bound = df['Survival_label_lower_bound']
+    y_upper_bound = df['Survival_label_upper_bound']
+    X = df.drop(['Survival_label_lower_bound',
+                 'Survival_label_upper_bound'], axis=1)
+    m = dmatrix_t(client, X, label_lower_bound=y_lower_bound,
+                  label_upper_bound=y_upper_bound)
+    base_params = {'verbosity': 0,
+                   'objective': 'survival:aft',
+                   'eval_metric': 'aft-nloglik',
+                   'learning_rate': 0.05,
+                   'aft_loss_distribution_scale': 1.20,
+                   'max_depth': 6,
+                   'lambda': 0.01,
+                   'alpha': 0.02}
+
+    nloglik_rec = {}
+    dists = ['normal', 'logistic', 'extreme']
+    for dist in dists:
+        params = base_params
+        params.update({'aft_loss_distribution': dist})
+        evals_result = {}
+        out = xgb.dask.train(client, params, m, num_boost_round=100,
+                             evals=[(m, 'train')])
+        evals_result = out['history']
+        nloglik_rec[dist] = evals_result['train']['aft-nloglik']
+        # The AFT metric (negative log likelihood) improves monotonically.
+        assert all(p >= q for p, q in zip(nloglik_rec[dist],
+                                          nloglik_rec[dist][1:]))
+    # For this data, the normal distribution works best.
+    assert nloglik_rec['normal'][-1] < 4.9
+    assert nloglik_rec['logistic'][-1] > 4.9
+    assert nloglik_rec['extreme'][-1] > 4.9
+
+
+def test_aft_survival():
+    with LocalCluster(n_workers=1) as cluster:
+        with Client(cluster) as client:
+            run_aft_survival(client, DaskDMatrix)
+
+
 class TestWithDask:
     def run_updater_test(self, client, params, num_rounds, dataset,
                          tree_method):