[dask] Support more meta data on functional interface. (#6132)
* Add base_margin, label_(lower|upper)_bound. * Test survival training with dask.
This commit is contained in:
parent
7065779afa
commit
33d80ffad0
@ -320,7 +320,8 @@ class DataIter:
|
|||||||
def data_handle(data, label=None, weight=None, base_margin=None,
|
def data_handle(data, label=None, weight=None, base_margin=None,
|
||||||
group=None,
|
group=None,
|
||||||
label_lower_bound=None, label_upper_bound=None,
|
label_lower_bound=None, label_upper_bound=None,
|
||||||
feature_names=None, feature_types=None):
|
feature_names=None, feature_types=None,
|
||||||
|
feature_weights=None):
|
||||||
from .data import dispatch_device_quantile_dmatrix_set_data
|
from .data import dispatch_device_quantile_dmatrix_set_data
|
||||||
from .data import _device_quantile_transform
|
from .data import _device_quantile_transform
|
||||||
data, feature_names, feature_types = _device_quantile_transform(
|
data, feature_names, feature_types = _device_quantile_transform(
|
||||||
@ -333,7 +334,8 @@ class DataIter:
|
|||||||
label_lower_bound=label_lower_bound,
|
label_lower_bound=label_lower_bound,
|
||||||
label_upper_bound=label_upper_bound,
|
label_upper_bound=label_upper_bound,
|
||||||
feature_names=feature_names,
|
feature_names=feature_names,
|
||||||
feature_types=feature_types)
|
feature_types=feature_types,
|
||||||
|
feature_weights=feature_weights)
|
||||||
try:
|
try:
|
||||||
# Differ the exception in order to return 0 and stop the iteration.
|
# Differ the exception in order to return 0 and stop the iteration.
|
||||||
# Exception inside a ctype callback function has no effect except
|
# Exception inside a ctype callback function has no effect except
|
||||||
|
|||||||
@ -178,6 +178,12 @@ class DaskDMatrix:
|
|||||||
to be present as a missing value. If None, defaults to np.nan.
|
to be present as a missing value. If None, defaults to np.nan.
|
||||||
weight : dask.array.Array/dask.dataframe.DataFrame
|
weight : dask.array.Array/dask.dataframe.DataFrame
|
||||||
Weight for each instance.
|
Weight for each instance.
|
||||||
|
base_margin : dask.array.Array/dask.dataframe.DataFrame
|
||||||
|
Global bias for each instance.
|
||||||
|
label_lower_bound : dask.array.Array/dask.dataframe.DataFrame
|
||||||
|
Upper bound for survival training.
|
||||||
|
label_upper_bound : dask.array.Array/dask.dataframe.DataFrame
|
||||||
|
Lower bound for survival training.
|
||||||
feature_names : list, optional
|
feature_names : list, optional
|
||||||
Set names for features.
|
Set names for features.
|
||||||
feature_types : list, optional
|
feature_types : list, optional
|
||||||
@ -191,6 +197,9 @@ class DaskDMatrix:
|
|||||||
label=None,
|
label=None,
|
||||||
missing=None,
|
missing=None,
|
||||||
weight=None,
|
weight=None,
|
||||||
|
base_margin=None,
|
||||||
|
label_lower_bound=None,
|
||||||
|
label_upper_bound=None,
|
||||||
feature_names=None,
|
feature_names=None,
|
||||||
feature_types=None):
|
feature_types=None):
|
||||||
_assert_dask_support()
|
_assert_dask_support()
|
||||||
@ -216,12 +225,17 @@ class DaskDMatrix:
|
|||||||
self.is_quantile = False
|
self.is_quantile = False
|
||||||
|
|
||||||
self._init = client.sync(self.map_local_data,
|
self._init = client.sync(self.map_local_data,
|
||||||
client, data, label, weight)
|
client, data, label=label, weights=weight,
|
||||||
|
base_margin=base_margin,
|
||||||
|
label_lower_bound=label_lower_bound,
|
||||||
|
label_upper_bound=label_upper_bound)
|
||||||
|
|
||||||
def __await__(self):
|
def __await__(self):
|
||||||
return self._init.__await__()
|
return self._init.__await__()
|
||||||
|
|
||||||
async def map_local_data(self, client, data, label=None, weights=None):
|
async def map_local_data(self, client, data, label=None, weights=None,
|
||||||
|
base_margin=None,
|
||||||
|
label_lower_bound=None, label_upper_bound=None):
|
||||||
'''Obtain references to local data.'''
|
'''Obtain references to local data.'''
|
||||||
|
|
||||||
def inconsistent(left, left_name, right, right_name):
|
def inconsistent(left, left_name, right, right_name):
|
||||||
@ -241,10 +255,10 @@ class DaskDMatrix:
|
|||||||
' chunks=(partition_size, X.shape[1])'
|
' chunks=(partition_size, X.shape[1])'
|
||||||
|
|
||||||
data = data.persist()
|
data = data.persist()
|
||||||
if label is not None:
|
for meta in [label, weights, base_margin, label_lower_bound,
|
||||||
label = label.persist()
|
label_upper_bound]:
|
||||||
if weights is not None:
|
if meta is not None:
|
||||||
weights = weights.persist()
|
meta = meta.persist()
|
||||||
# Breaking data into partitions, a trick borrowed from dask_xgboost.
|
# Breaking data into partitions, a trick borrowed from dask_xgboost.
|
||||||
|
|
||||||
# `to_delayed` downgrades high-level objects into numpy or pandas
|
# `to_delayed` downgrades high-level objects into numpy or pandas
|
||||||
@ -254,29 +268,37 @@ class DaskDMatrix:
|
|||||||
check_columns(X_parts)
|
check_columns(X_parts)
|
||||||
X_parts = X_parts.flatten().tolist()
|
X_parts = X_parts.flatten().tolist()
|
||||||
|
|
||||||
if label is not None:
|
def flatten_meta(meta):
|
||||||
y_parts = label.to_delayed()
|
if meta is not None:
|
||||||
if isinstance(y_parts, numpy.ndarray):
|
meta_parts = meta.to_delayed()
|
||||||
check_columns(y_parts)
|
if isinstance(meta_parts, numpy.ndarray):
|
||||||
y_parts = y_parts.flatten().tolist()
|
check_columns(meta_parts)
|
||||||
if weights is not None:
|
meta_parts = meta_parts.flatten().tolist()
|
||||||
w_parts = weights.to_delayed()
|
return meta_parts
|
||||||
if isinstance(w_parts, numpy.ndarray):
|
return None
|
||||||
check_columns(w_parts)
|
|
||||||
w_parts = w_parts.flatten().tolist()
|
y_parts = flatten_meta(label)
|
||||||
|
w_parts = flatten_meta(weights)
|
||||||
|
margin_parts = flatten_meta(base_margin)
|
||||||
|
ll_parts = flatten_meta(label_lower_bound)
|
||||||
|
lu_parts = flatten_meta(label_upper_bound)
|
||||||
|
|
||||||
parts = [X_parts]
|
parts = [X_parts]
|
||||||
meta_names = []
|
meta_names = []
|
||||||
if label is not None:
|
|
||||||
assert len(X_parts) == len(
|
def append_meta(m_parts, name: str):
|
||||||
y_parts), inconsistent(X_parts, 'X', y_parts, 'labels')
|
if m_parts is not None:
|
||||||
parts.append(y_parts)
|
assert len(X_parts) == len(
|
||||||
meta_names.append('labels')
|
m_parts), inconsistent(X_parts, 'X', m_parts, name)
|
||||||
if weights is not None:
|
parts.append(m_parts)
|
||||||
assert len(X_parts) == len(
|
meta_names.append(name)
|
||||||
w_parts), inconsistent(X_parts, 'X', w_parts, 'weights')
|
|
||||||
parts.append(w_parts)
|
append_meta(y_parts, 'labels')
|
||||||
meta_names.append('weights')
|
append_meta(w_parts, 'weights')
|
||||||
|
append_meta(margin_parts, 'base_margin')
|
||||||
|
append_meta(ll_parts, 'label_lower_bound')
|
||||||
|
append_meta(lu_parts, 'label_upper_bound')
|
||||||
|
|
||||||
parts = list(map(delayed, zip(*parts)))
|
parts = list(map(delayed, zip(*parts)))
|
||||||
|
|
||||||
parts = client.compute(parts)
|
parts = client.compute(parts)
|
||||||
@ -339,6 +361,9 @@ def _get_worker_parts(worker_map, meta_names, worker):
|
|||||||
data = None
|
data = None
|
||||||
labels = None
|
labels = None
|
||||||
weights = None
|
weights = None
|
||||||
|
base_margin = None
|
||||||
|
label_lower_bound = None
|
||||||
|
label_upper_bound = None
|
||||||
|
|
||||||
local_data = list(zip(*list_of_parts))
|
local_data = list(zip(*list_of_parts))
|
||||||
data = local_data[0]
|
data = local_data[0]
|
||||||
@ -348,8 +373,15 @@ def _get_worker_parts(worker_map, meta_names, worker):
|
|||||||
labels = part
|
labels = part
|
||||||
if meta_names[i] == 'weights':
|
if meta_names[i] == 'weights':
|
||||||
weights = part
|
weights = part
|
||||||
|
if meta_names[i] == 'base_margin':
|
||||||
|
base_margin = part
|
||||||
|
if meta_names[i] == 'label_lower_bound':
|
||||||
|
label_lower_bound = part
|
||||||
|
if meta_names[i] == 'label_upper_bound':
|
||||||
|
label_upper_bound = part
|
||||||
|
|
||||||
return data, labels, weights
|
return (data, labels, weights, base_margin, label_lower_bound,
|
||||||
|
label_upper_bound)
|
||||||
|
|
||||||
|
|
||||||
class DaskPartitionIter(DataIter): # pylint: disable=R0902
|
class DaskPartitionIter(DataIter): # pylint: disable=R0902
|
||||||
@ -456,13 +488,22 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
|||||||
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
def __init__(self, client, data, label=None, weight=None,
|
def __init__(self, client,
|
||||||
|
data,
|
||||||
|
label=None,
|
||||||
missing=None,
|
missing=None,
|
||||||
|
weight=None,
|
||||||
|
base_margin=None,
|
||||||
|
label_lower_bound=None,
|
||||||
|
label_upper_bound=None,
|
||||||
feature_names=None,
|
feature_names=None,
|
||||||
feature_types=None,
|
feature_types=None,
|
||||||
max_bin=256):
|
max_bin=256):
|
||||||
super().__init__(client=client, data=data, label=label, weight=weight,
|
super().__init__(client=client, data=data, label=label,
|
||||||
missing=missing,
|
missing=missing,
|
||||||
|
weight=weight, base_margin=base_margin,
|
||||||
|
label_lower_bound=label_lower_bound,
|
||||||
|
label_upper_bound=label_upper_bound,
|
||||||
feature_names=feature_names,
|
feature_names=feature_names,
|
||||||
feature_types=feature_types)
|
feature_types=feature_types)
|
||||||
self.max_bin = max_bin
|
self.max_bin = max_bin
|
||||||
@ -491,8 +532,13 @@ def _create_device_quantile_dmatrix(feature_names, feature_types,
|
|||||||
max_bin=max_bin)
|
max_bin=max_bin)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
data, labels, weights = _get_worker_parts(worker_map, meta_names, worker)
|
(data, labels, weights, base_margin,
|
||||||
it = DaskPartitionIter(data=data, label=labels, weight=weights)
|
label_lower_bound, label_upper_bound) = _get_worker_parts(
|
||||||
|
worker_map, meta_names, worker)
|
||||||
|
it = DaskPartitionIter(data=data, label=labels, weight=weights,
|
||||||
|
base_margin=base_margin,
|
||||||
|
label_lower_bound=label_lower_bound,
|
||||||
|
label_upper_bound=label_upper_bound)
|
||||||
|
|
||||||
dmatrix = DeviceQuantileDMatrix(it,
|
dmatrix = DeviceQuantileDMatrix(it,
|
||||||
missing=missing,
|
missing=missing,
|
||||||
@ -524,20 +570,31 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
|
|||||||
feature_types=feature_types)
|
feature_types=feature_types)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
data, labels, weights = _get_worker_parts(worker_map, meta_names, worker)
|
def concat_or_none(data):
|
||||||
data = concat(data)
|
if data is not None:
|
||||||
|
return concat(data)
|
||||||
|
return data
|
||||||
|
|
||||||
if labels:
|
(data, labels, weights, base_margin,
|
||||||
labels = concat(labels)
|
label_lower_bound, label_upper_bound) = _get_worker_parts(
|
||||||
if weights:
|
worker_map, meta_names, worker)
|
||||||
weights = concat(weights)
|
|
||||||
|
labels = concat_or_none(labels)
|
||||||
|
weights = concat_or_none(weights)
|
||||||
|
base_margin = concat_or_none(base_margin)
|
||||||
|
label_lower_bound = concat_or_none(label_lower_bound)
|
||||||
|
label_upper_bound = concat_or_none(label_upper_bound)
|
||||||
|
|
||||||
|
data = concat(data)
|
||||||
dmatrix = DMatrix(data,
|
dmatrix = DMatrix(data,
|
||||||
labels,
|
labels,
|
||||||
weight=weights,
|
|
||||||
missing=missing,
|
missing=missing,
|
||||||
feature_names=feature_names,
|
feature_names=feature_names,
|
||||||
feature_types=feature_types,
|
feature_types=feature_types,
|
||||||
nthread=worker.nthreads)
|
nthread=worker.nthreads)
|
||||||
|
dmatrix.set_info(base_margin=base_margin, weight=weights,
|
||||||
|
label_lower_bound=label_lower_bound,
|
||||||
|
label_upper_bound=label_upper_bound)
|
||||||
return dmatrix
|
return dmatrix
|
||||||
|
|
||||||
|
|
||||||
@ -683,7 +740,8 @@ async def _direct_predict_impl(client, data, predict_fn):
|
|||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-statements
|
# pylint: disable=too-many-statements
|
||||||
async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwargs):
|
async def _predict_async(client: Client, model, data, missing=numpy.nan,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
if isinstance(model, Booster):
|
if isinstance(model, Booster):
|
||||||
booster = model
|
booster = model
|
||||||
|
|||||||
@ -3,9 +3,10 @@ import pytest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
import json
|
import json
|
||||||
from pathlib import Path
|
import os
|
||||||
|
|
||||||
|
dpath = os.path.join(tm.PROJECT_ROOT, 'demo', 'data')
|
||||||
|
|
||||||
dpath = Path('demo/data')
|
|
||||||
|
|
||||||
def test_aft_survival_toy_data():
|
def test_aft_survival_toy_data():
|
||||||
# See demo/aft_survival/aft_survival_viz_demo.py
|
# See demo/aft_survival/aft_survival_viz_demo.py
|
||||||
@ -51,10 +52,10 @@ def test_aft_survival_toy_data():
|
|||||||
for tree in model_json:
|
for tree in model_json:
|
||||||
assert gather_split_thresholds(tree).issubset({2.5, 3.5, 4.5})
|
assert gather_split_thresholds(tree).issubset({2.5, 3.5, 4.5})
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_pandas())
|
@pytest.mark.skipif(**tm.no_pandas())
|
||||||
def test_aft_survival_demo_data():
|
def test_aft_survival_demo_data():
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
df = pd.read_csv(dpath / 'veterans_lung_cancer.csv')
|
df = pd.read_csv(os.path.join(dpath, 'veterans_lung_cancer.csv'))
|
||||||
|
|
||||||
y_lower_bound = df['Survival_label_lower_bound']
|
y_lower_bound = df['Survival_label_lower_bound']
|
||||||
y_upper_bound = df['Survival_label_upper_bound']
|
y_upper_bound = df['Survival_label_upper_bound']
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import testing as tm
|
import testing as tm
|
||||||
import pytest
|
import pytest
|
||||||
import unittest
|
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
import sys
|
import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -482,16 +481,62 @@ def test_predict():
|
|||||||
assert pred.ndim == 1
|
assert pred.ndim == 1
|
||||||
assert pred.shape[0] == kRows
|
assert pred.shape[0] == kRows
|
||||||
|
|
||||||
margin = xgb.dask.predict(client, model=booster, data=dtrain, output_margin=True)
|
margin = xgb.dask.predict(client, model=booster, data=dtrain,
|
||||||
|
output_margin=True)
|
||||||
assert margin.ndim == 1
|
assert margin.ndim == 1
|
||||||
assert margin.shape[0] == kRows
|
assert margin.shape[0] == kRows
|
||||||
|
|
||||||
shap = xgb.dask.predict(client, model=booster, data=dtrain, pred_contribs=True)
|
shap = xgb.dask.predict(client, model=booster, data=dtrain,
|
||||||
|
pred_contribs=True)
|
||||||
assert shap.ndim == 2
|
assert shap.ndim == 2
|
||||||
assert shap.shape[0] == kRows
|
assert shap.shape[0] == kRows
|
||||||
assert shap.shape[1] == kCols + 1
|
assert shap.shape[1] == kCols + 1
|
||||||
|
|
||||||
|
|
||||||
|
def run_aft_survival(client, dmatrix_t):
|
||||||
|
# survival doesn't handle empty dataset well.
|
||||||
|
df = dd.read_csv(os.path.join(tm.PROJECT_ROOT, 'demo', 'data',
|
||||||
|
'veterans_lung_cancer.csv'))
|
||||||
|
y_lower_bound = df['Survival_label_lower_bound']
|
||||||
|
y_upper_bound = df['Survival_label_upper_bound']
|
||||||
|
X = df.drop(['Survival_label_lower_bound',
|
||||||
|
'Survival_label_upper_bound'], axis=1)
|
||||||
|
m = dmatrix_t(client, X, label_lower_bound=y_lower_bound,
|
||||||
|
label_upper_bound=y_upper_bound)
|
||||||
|
base_params = {'verbosity': 0,
|
||||||
|
'objective': 'survival:aft',
|
||||||
|
'eval_metric': 'aft-nloglik',
|
||||||
|
'learning_rate': 0.05,
|
||||||
|
'aft_loss_distribution_scale': 1.20,
|
||||||
|
'max_depth': 6,
|
||||||
|
'lambda': 0.01,
|
||||||
|
'alpha': 0.02}
|
||||||
|
|
||||||
|
nloglik_rec = {}
|
||||||
|
dists = ['normal', 'logistic', 'extreme']
|
||||||
|
for dist in dists:
|
||||||
|
params = base_params
|
||||||
|
params.update({'aft_loss_distribution': dist})
|
||||||
|
evals_result = {}
|
||||||
|
out = xgb.dask.train(client, params, m, num_boost_round=100,
|
||||||
|
evals=[(m, 'train')])
|
||||||
|
evals_result = out['history']
|
||||||
|
nloglik_rec[dist] = evals_result['train']['aft-nloglik']
|
||||||
|
# AFT metric (negative log likelihood) improve monotonically
|
||||||
|
assert all(p >= q for p, q in zip(nloglik_rec[dist],
|
||||||
|
nloglik_rec[dist][:1]))
|
||||||
|
# For this data, normal distribution works the best
|
||||||
|
assert nloglik_rec['normal'][-1] < 4.9
|
||||||
|
assert nloglik_rec['logistic'][-1] > 4.9
|
||||||
|
assert nloglik_rec['extreme'][-1] > 4.9
|
||||||
|
|
||||||
|
|
||||||
|
def test_aft_survival():
|
||||||
|
with LocalCluster(n_workers=1) as cluster:
|
||||||
|
with Client(cluster) as client:
|
||||||
|
run_aft_survival(client, DaskDMatrix)
|
||||||
|
|
||||||
|
|
||||||
class TestWithDask:
|
class TestWithDask:
|
||||||
def run_updater_test(self, client, params, num_rounds, dataset,
|
def run_updater_test(self, client, params, num_rounds, dataset,
|
||||||
tree_method):
|
tree_method):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user