Fix dask prediction. (#4941)

* Fix dask prediction.

* Add better error messages for wrong partition.
This commit is contained in:
Jiaming Yuan 2019-10-14 23:19:34 -04:00 committed by GitHub
parent b61d534472
commit 2ebdec8aa6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 51 additions and 24 deletions

View File

@ -7,11 +7,10 @@ from dask import array as da
def main(client):
# generate some random data for demonstration
n = 100
m = 100000
partition_size = 1000
X = da.random.random((m, n), partition_size)
y = da.random.random(m, partition_size)
n = 100
X = da.random.random(size=(m, n), chunks=100)
y = da.random.random(size=(m, ), chunks=100)
# DaskDMatrix acts like normal DMatrix, works as a proxy for local
# DMatrix scatter around workers.
@ -38,6 +37,6 @@ def main(client):
if __name__ == '__main__':
# or use other clusters for scaling
with LocalCluster(n_workers=4, threads_per_worker=1) as cluster:
with LocalCluster(n_workers=7, threads_per_worker=1) as cluster:
with Client(cluster) as client:
main(client)

View File

@ -6,11 +6,11 @@ from xgboost.dask import DaskDMatrix
def main(client):
n = 100
# generate some random data for demonstration
m = 100000
partition_size = 1000
X = da.random.random((m, n), partition_size)
y = da.random.random(m, partition_size)
n = 100
X = da.random.random(size=(m, n), chunks=100)
y = da.random.random(size=(m, ), chunks=100)
# DaskDMatrix acts like normal DMatrix, works as a proxy for local
# DMatrix scatter around workers.
@ -23,6 +23,7 @@ def main(client):
output = xgb.dask.train(client,
{'verbosity': 2,
'nthread': 1,
# Golden line for GPU training
'tree_method': 'gpu_hist'},
dtrain,
num_boost_round=4, evals=[(dtrain, 'train')])

View File

@ -139,13 +139,14 @@ class DaskDMatrix:
self._missing = missing
if len(data.shape) != 2:
_expect('2 dimensions input', data.shape)
raise ValueError(
'Expecting 2 dimensional input, got: {shape}'.format(
shape=data.shape))
if not any(isinstance(data, t) for t in (dd.DataFrame, da.Array)):
if not isinstance(data, (dd.DataFrame, da.Array)):
raise TypeError(_expect((dd.DataFrame, da.Array), type(data)))
if not any(
isinstance(label, t)
for t in (dd.DataFrame, da.Array, dd.Series, type(None))):
if not isinstance(label, (dd.DataFrame, da.Array, dd.Series,
type(None))):
raise TypeError(
_expect((dd.DataFrame, da.Array, dd.Series), type(label)))
@ -158,6 +159,23 @@ class DaskDMatrix:
async def map_local_data(self, client, data, label=None, weights=None):
'''Obtain references to local data.'''
def inconsistent(left, left_name, right, right_name):
msg = 'Partitions between {a_name} and {b_name} are not ' \
'consistent: {a_len} != {b_len}. ' \
'Please try to repartition/rechunk your data.'.format(
a_name=left_name, b_name=right_name, a_len=len(left),
b_len=len(right)
)
return msg
def check_columns(parts):
# x is required to be 2 dim in __init__
assert parts.ndim == 1 or parts.shape[1], 'Data should be' \
' partitioned by row. To avoid this specify the number' \
' of columns for your dask Array explicitly. e.g.' \
' chunks=(partition_size, X.shape[1])'
data = data.persist()
if label is not None:
label = label.persist()
@ -169,28 +187,28 @@ class DaskDMatrix:
# equivalents.
X_parts = data.to_delayed()
if isinstance(X_parts, numpy.ndarray):
assert X_parts.shape[1] == 1
check_columns(X_parts)
X_parts = X_parts.flatten().tolist()
if label is not None:
y_parts = label.to_delayed()
if isinstance(y_parts, numpy.ndarray):
assert y_parts.ndim == 1 or y_parts.shape[1] == 1
check_columns(y_parts)
y_parts = y_parts.flatten().tolist()
if weights is not None:
w_parts = weights.to_delayed()
if isinstance(w_parts, numpy.ndarray):
assert w_parts.ndim == 1 or w_parts.shape[1] == 1
check_columns(w_parts)
w_parts = w_parts.flatten().tolist()
parts = [X_parts]
if label is not None:
assert len(X_parts) == len(
y_parts), 'Partitions between X and y are not consistent'
y_parts), inconsistent(X_parts, 'X', y_parts, 'labels')
parts.append(y_parts)
if weights is not None:
assert len(X_parts) == len(
w_parts), 'Partitions between X and weight are not consistent.'
w_parts), inconsistent(X_parts, 'X', w_parts, 'weights')
parts.append(w_parts)
parts = list(map(delayed, zip(*parts)))
@ -275,7 +293,11 @@ class DaskDMatrix:
cols = 0
for shape in shapes:
rows += shape[0]
cols += shape[1]
c = shape[1]
assert cols in (0, c), 'Shape between partitions are not the' \
' same. Got: {left} and {right}'.format(left=c, right=cols)
cols = c
return (rows, cols)

View File

@ -185,7 +185,7 @@ void SimpleCSRSource::CopyFrom(std::string const& cuda_interfaces_str,
cuda_interfaces_str.size()});
std::vector<Json> const& columns = get<Array>(interfaces);
size_t n_columns = columns.size();
CHECK_GT(n_columns, 0) << "Number of columns must not be greater than 0.";
CHECK_GT(n_columns, 0) << "Number of columns must not eqaul to 0.";
auto const& typestr = get<String const>(columns[0]["typestr"]);
CHECK_EQ(typestr.size(), 3) << ColumnarErrors::TypestrFormat();

View File

@ -21,12 +21,12 @@ except ImportError:
pass
kRows = 1000
kCols = 10
def generate_array():
n = 10
partition_size = 20
X = da.random.random((kRows, n), partition_size)
X = da.random.random((kRows, kCols), partition_size)
y = da.random.random(kRows, partition_size)
return X, y
@ -44,7 +44,7 @@ def test_from_dask_dataframe(client):
prediction = xgb.dask.predict(client, model=booster, data=dtrain)
assert isinstance(prediction, da.Array)
assert prediction.shape[0] == kRows, prediction
assert prediction.shape[0] == kRows and prediction.shape[1] == kCols
with pytest.raises(ValueError):
# evals_result is not supported in dask interface.
@ -59,6 +59,7 @@ def test_from_dask_array(client):
result = xgb.dask.train(client, {}, dtrain)
prediction = xgb.dask.predict(client, result, dtrain)
assert prediction.shape[0] == kRows and prediction.shape[1] == kCols
assert isinstance(prediction, da.Array)
@ -71,6 +72,8 @@ def test_regressor(client):
regressor.fit(X, y, eval_set=[(X, y)])
prediction = regressor.predict(X)
assert prediction.shape[0] == kRows and prediction.shape[1] == kCols
history = regressor.evals_result()
assert isinstance(prediction, da.Array)
@ -88,6 +91,8 @@ def test_classifier(client):
classifier.fit(X, y, eval_set=[(X, y)])
prediction = classifier.predict(X)
assert prediction.shape[0] == kRows and prediction.shape[1] == kCols
history = classifier.evals_result()
assert isinstance(prediction, da.Array)