Lazy import dask libraries. (#6309)

* Lazy import dask libraries.

* Lint && fix.

* Use short name.
Jiaming Yuan 2020-10-29 06:50:11 +08:00 committed by GitHub
parent dfac5f89e9
commit 74ea82209b
2 changed files with 93 additions and 49 deletions

python-package/xgboost/compat.py

@@ -1,11 +1,10 @@
 # coding: utf-8
 # pylint: disable= invalid-name, unused-import
 """For compatibility and optional dependencies."""
-import abc
-import os
 import sys
-from pathlib import PurePath
+import types
+import importlib.util
+import logging
 import numpy as np

 assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'
@@ -120,3 +119,60 @@ except ImportError:
     sparse = False
     scipy_sparse = False
     SCIPY_INSTALLED = False
+
+
+# Modified from tensorflow with added caching.  There's a `LazyLoader` in
+# `importlib.util`, but it's unclear from its documentation how to use it.  This
+# one is easy to understand and works out of the box.
+#
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this
+# file except in compliance with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the specific language governing
+# permissions and limitations under the License.
+class LazyLoader(types.ModuleType):
+    """Lazily import a module, mainly to avoid pulling in large dependencies.
+    """
+
+    def __init__(self, local_name, parent_module_globals, name, warning=None):
+        self._local_name = local_name
+        self._parent_module_globals = parent_module_globals
+        self._warning = warning
+        self.module = None
+
+        super().__init__(name)
+
+    def _load(self):
+        """Load the module and insert it into the parent's globals."""
+        # Import the target module and insert it into the parent's namespace
+        module = importlib.import_module(self.__name__)
+        self._parent_module_globals[self._local_name] = module
+
+        # Emit a warning if one was specified
+        if self._warning:
+            logging.warning(self._warning)
+            # Make sure to only warn once.
+            self._warning = None
+
+        # Update this object's dict so that if someone keeps a reference to the
+        # LazyLoader, lookups are efficient (__getattr__ is only called on lookups
+        # that fail).
+        self.__dict__.update(module.__dict__)
+
+        return module
+
+    def __getattr__(self, item):
+        if not self.module:
+            self.module = self._load()
+        return getattr(self.module, item)
+
+    def __dir__(self):
+        if not self.module:
+            self.module = self._load()
+        return dir(self.module)
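
For illustration only (not part of the diff), here is a minimal standalone sketch of how the class above behaves, assuming the patched xgboost from this commit is importable; the `csv_proxy` name and the stdlib `csv` module are arbitrary stand-ins for a heavy optional dependency:

import sys

from xgboost.compat import LazyLoader  # the class added in this diff

# Construction does not import anything; the proxy only records the target name.
csv_proxy = LazyLoader('csv_proxy', globals(), 'csv',
                       warning='csv was imported lazily on first use.')

print('csv' in sys.modules)        # usually False here: csv is not imported yet
rows = csv_proxy.reader(['a,b'])   # first attribute access triggers _load()
print(list(rows))                  # [['a', 'b']]
print('csv' in sys.modules)        # True: the real module is loaded, and
                                   # globals()['csv_proxy'] now refers to it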

python-package/xgboost/dask.py

@@ -23,7 +23,7 @@ import numpy

 from . import rabit
-from .compat import DASK_INSTALLED
+from .compat import LazyLoader
 from .compat import sparse, scipy_sparse
 from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
 from .compat import CUDF_concat
@@ -35,23 +35,11 @@ from .tracker import RabitTracker
 from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
 from .sklearn import xgboost_model_doc

-try:
-    from dask.distributed import Client, get_client
-    from dask.distributed import comm as distributed_comm
-    from dask.distributed import wait as distributed_wait
-    from dask.distributed import get_worker as distributed_get_worker
-    from dask import dataframe as dd
-    from dask import array as da
-    from dask import delayed
-except ImportError:
-    Client = None
-    get_client = None
-    distributed_comm = None
-    distributed_wait = None
-    distributed_get_worker = None
-    dd = None
-    da = None
-    delayed = None
+dd = LazyLoader('dd', globals(), 'dask.dataframe')
+da = LazyLoader('da', globals(), 'dask.array')
+dask = LazyLoader('dask', globals(), 'dask')
+distributed = LazyLoader('distributed', globals(), 'dask.distributed')

 # Current status is considered as initial support, many features are
 # not properly supported yet.
@@ -91,12 +79,12 @@ def _start_tracker(host, n_workers):

 def _assert_dask_support():
-    if not DASK_INSTALLED:
+    try:
+        import dask             # pylint: disable=W0621,W0611
+    except ImportError as e:
         raise ImportError(
-            'Dask needs to be installed in order to use this module')
-    if not distributed_wait:
-        raise ImportError(
-            'distributed needs to be installed in order to use this module.')
+            'Dask needs to be installed in order to use this module') from e
+
     if platform.system() == 'Windows':
         msg = 'Windows is not officially supported for dask/xgboost,'
         msg += ' contribution are welcomed.'
@@ -107,7 +95,7 @@ class RabitContext:
     '''A context controling rabit initialization and finalization.'''
     def __init__(self, args):
         self.args = args
-        worker = distributed_get_worker()
+        worker = distributed.get_worker()
         self.args.append(
             ('DMLC_TASK_ID=[xgboost.dask]:' + str(worker.address)).encode())
@@ -146,10 +134,10 @@ def concat(value):  # pylint: disable=too-many-return-statements

 def _xgb_get_client(client):
     '''Simple wrapper around testing None.'''
-    if not isinstance(client, (type(get_client()), type(None))):
+    if not isinstance(client, (type(distributed.get_client()), type(None))):
         raise TypeError(
-            _expect([type(get_client()), type(None)], type(client)))
-    ret = get_client() if client is None else client
+            _expect([type(distributed.get_client()), type(None)], type(client)))
+    ret = distributed.get_client() if client is None else client
     return ret
@@ -217,7 +205,7 @@ class DaskDMatrix:
                  feature_names=None,
                  feature_types=None):
         _assert_dask_support()
-        client: Client = _xgb_get_client(client)
+        client: distributed.Client = _xgb_get_client(client)

         self.feature_names = feature_names
         self.feature_types = feature_types
@@ -313,10 +301,10 @@ class DaskDMatrix:
         append_meta(ll_parts, 'label_lower_bound')
         append_meta(lu_parts, 'label_upper_bound')

-        parts = list(map(delayed, zip(*parts)))
+        parts = list(map(dask.delayed, zip(*parts)))
         parts = client.compute(parts)
-        await distributed_wait(parts)  # async wait for parts to be computed
+        await distributed.wait(parts)  # async wait for parts to be computed

         for part in parts:
             assert part.status == 'finished'
@@ -354,7 +342,7 @@ class DaskDMatrix:
 def _get_worker_parts_ordered(has_base_margin, worker_map, partition_order,
                               worker):
     list_of_parts = worker_map[worker.address]
-    client = get_client()
+    client = distributed.get_client()
     list_of_parts_value = client.gather(list_of_parts)

     result = []
@@ -378,7 +366,7 @@ def _get_worker_parts(worker_map, meta_names, worker):
     # `_get_worker_parts` is launched inside worker. In dask side
     # this should be equal to `worker._get_client`.
-    client = get_client()
+    client = distributed.get_client()
     list_of_parts = client.gather(list_of_parts)
     data = None
     labels = None
@@ -544,7 +532,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
 def _create_device_quantile_dmatrix(feature_names, feature_types,
                                     meta_names, missing, worker_map,
                                     max_bin):
-    worker = distributed_get_worker()
+    worker = distributed.get_worker()
     if worker.address not in set(worker_map.keys()):
         msg = 'worker {address} has an empty DMatrix. ' \
             'All workers associated with this DMatrix: {workers}'.format(
@@ -584,7 +572,7 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
     A DMatrix object.

     '''
-    worker = distributed_get_worker()
+    worker = distributed.get_worker()
     if worker.address not in set(worker_map.keys()):
         msg = 'worker {address} has an empty DMatrix. ' \
             'All workers associated with this DMatrix: {workers}'.format(
@@ -630,9 +618,9 @@ def _dmatrix_from_worker_map(is_quantile, **kwargs):
     return _create_dmatrix(**kwargs)


-async def _get_rabit_args(worker_map, client: Client):
+async def _get_rabit_args(worker_map, client):
     '''Get rabit context arguments from data distribution in DaskDMatrix.'''
-    host = distributed_comm.get_address_host(client.scheduler.address)
+    host = distributed.comm.get_address_host(client.scheduler.address)
     env = await client.run_on_scheduler(
         _start_tracker, host.strip('/:'), len(worker_map))
     rabit_args = [('%s=%s' % item).encode() for item in env.items()]
@@ -648,7 +636,7 @@ async def _get_rabit_args(worker_map, client: Client):
 async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
                        early_stopping_rounds=None, **kwargs):
     _assert_dask_support()
-    client: Client = _xgb_get_client(client)
+    client: distributed.Client = _xgb_get_client(client)
     if 'evals_result' in kwargs.keys():
         raise ValueError(
             'evals_result is not supported in dask interface.',
@@ -662,7 +650,7 @@ async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
         '''
         LOGGER.info('Training on %s', str(worker_addr))
-        worker = distributed_get_worker()
+        worker = distributed.get_worker()
         with RabitContext(rabit_args):
             local_dtrain = _dmatrix_from_worker_map(**dtrain_ref)

             local_evals = []
@@ -774,7 +762,7 @@ async def _direct_predict_impl(client, data, predict_fn):

 # pylint: disable=too-many-statements
-async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwargs):
+async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
     if isinstance(model, Booster):
         booster = model
     elif isinstance(model, dict):
@@ -786,7 +774,7 @@ async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwargs):
             type(data)))

     def mapped_predict(partition, is_df):
-        worker = distributed_get_worker()
+        worker = distributed.get_worker()
         booster.set_param({'nthread': worker.nthreads})
         m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
         predt = booster.predict(m, validate_features=False, **kwargs)
@@ -813,7 +801,7 @@ async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwargs):
         '''Perform prediction on each worker.'''
         LOGGER.info('Predicting on %d', worker_id)
-        worker = distributed_get_worker()
+        worker = distributed.get_worker()
         list_of_parts = _get_worker_parts_ordered(
             has_margin, worker_map, partition_order, worker)
         predictions = []
@@ -832,14 +820,14 @@ async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwargs):
                 validate_features=local_part.num_row() != 0,
                 **kwargs)
             columns = 1 if len(predt.shape) == 1 else predt.shape[1]
-            ret = ((delayed(predt), columns), order)
+            ret = ((dask.delayed(predt), columns), order)
             predictions.append(ret)
         return predictions

     def dispatched_get_shape(worker_id):
         '''Get shape of data in each worker.'''
         LOGGER.info('Get shape on %d', worker_id)
-        worker = distributed_get_worker()
+        worker = distributed.get_worker()
         list_of_parts = _get_worker_parts_ordered(
             False,
             worker_map,
@@ -930,7 +918,7 @@ async def _inplace_predict_async(client, model, data,
         raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))

     def mapped_predict(data, is_df):
-        worker = distributed_get_worker()
+        worker = distributed.get_worker()
         booster.set_param({'nthread': worker.nthreads})
         prediction = booster.inplace_predict(
             data,
@@ -1072,7 +1060,7 @@ class DaskScikitLearnBase(XGBModel):
         return self.client.sync(_).__await__()

     @property
-    def client(self) -> Client:
+    def client(self):
         '''The dask client used in this model.'''
         client = _xgb_get_client(self._client)
         return client
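
Taken together, the dask.py changes follow a single pattern: module-level LazyLoader proxies replace the old try/except import block, and every former `distributed_*` alias becomes an attribute access on the `distributed` proxy, so dask is only imported when one of these code paths actually runs. A condensed sketch of that pattern (the `current_worker_address` helper is hypothetical, for illustration only):

from xgboost.compat import LazyLoader

# Proxies are created at import time, but dask itself is not imported yet.
dask = LazyLoader('dask', globals(), 'dask')
distributed = LazyLoader('distributed', globals(), 'dask.distributed')


def current_worker_address():
    # dask.distributed is imported here, on the first call; like the helpers in
    # the diff above, this must run inside a dask worker for get_worker() to work.
    return distributed.get_worker().address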