Fix prediction on df with latest dask. (#6969)

This commit is contained in:
Jiaming Yuan 2021-05-19 12:23:03 +08:00 committed by GitHub
parent 6e104f0570
commit 7e846bb965
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1000,7 +1000,7 @@ async def _direct_predict_impl( # pylint: disable=too-many-branches
output_shape: Tuple[int, ...],
meta: Dict[int, str],
) -> _DaskCollection:
columns = list(meta.keys())
columns = tuple(meta.keys())
if len(output_shape) >= 3 and isinstance(data, dd.DataFrame):
# Without this check, dask will finish the prediction silently even if output
# dimension is greater than 3. But during map_partitions, dask passes a