Lazy import dask libraries. (#6309)

* Lazy import dask libraries.

* Lint && fix.

* Use short name.
This commit is contained in:
Jiaming Yuan 2020-10-29 06:50:11 +08:00 committed by GitHub
parent dfac5f89e9
commit 74ea82209b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 93 additions and 49 deletions

View File

@ -1,11 +1,10 @@
# coding: utf-8
# pylint: disable= invalid-name, unused-import
"""For compatibility and optional dependencies."""
import abc
import os
import sys
from pathlib import PurePath
import types
import importlib.util
import logging
import numpy as np
assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'
@ -120,3 +119,60 @@ except ImportError:
sparse = False
scipy_sparse = False
SCIPY_INSTALLED = False
# Modified from tensorflow with added caching. There's a `LazyLoader` in
# `importlib.utils`, except it's unclear from its document on how to use it. This one
# seems to be easy to understand and works out of box.
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this
# file except in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the specific language governing
# permissions and limitations under the License.
class LazyLoader(types.ModuleType):
"""Lazily import a module, mainly to avoid pulling in large dependencies.
"""
def __init__(self, local_name, parent_module_globals, name, warning=None):
self._local_name = local_name
self._parent_module_globals = parent_module_globals
self._warning = warning
self.module = None
super().__init__(name)
def _load(self):
"""Load the module and insert it into the parent's globals."""
# Import the target module and insert it into the parent's namespace
module = importlib.import_module(self.__name__)
self._parent_module_globals[self._local_name] = module
# Emit a warning if one was specified
if self._warning:
logging.warning(self._warning)
# Make sure to only warn once.
self._warning = None
# Update this object's dict so that if someone keeps a reference to the
# LazyLoader, lookups are efficient (__getattr__ is only called on lookups
# that fail).
self.__dict__.update(module.__dict__)
return module
def __getattr__(self, item):
if not self.module:
self.module = self._load()
return getattr(self.module, item)
def __dir__(self):
if not self.module:
self.module = self._load()
return dir(self.module)

View File

@ -23,7 +23,7 @@ import numpy
from . import rabit
from .compat import DASK_INSTALLED
from .compat import LazyLoader
from .compat import sparse, scipy_sparse
from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
from .compat import CUDF_concat
@ -35,23 +35,11 @@ from .tracker import RabitTracker
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
from .sklearn import xgboost_model_doc
try:
from dask.distributed import Client, get_client
from dask.distributed import comm as distributed_comm
from dask.distributed import wait as distributed_wait
from dask.distributed import get_worker as distributed_get_worker
from dask import dataframe as dd
from dask import array as da
from dask import delayed
except ImportError:
Client = None
get_client = None
distributed_comm = None
distributed_wait = None
distributed_get_worker = None
dd = None
da = None
delayed = None
dd = LazyLoader('dd', globals(), 'dask.dataframe')
da = LazyLoader('da', globals(), 'dask.array')
dask = LazyLoader('dask', globals(), 'dask')
distributed = LazyLoader('distributed', globals(), 'dask.distributed')
# Current status is considered as initial support, many features are
# not properly supported yet.
@ -91,12 +79,12 @@ def _start_tracker(host, n_workers):
def _assert_dask_support():
if not DASK_INSTALLED:
try:
import dask # pylint: disable=W0621,W0611
except ImportError as e:
raise ImportError(
'Dask needs to be installed in order to use this module')
if not distributed_wait:
raise ImportError(
'distributed needs to be installed in order to use this module.')
'Dask needs to be installed in order to use this module') from e
if platform.system() == 'Windows':
msg = 'Windows is not officially supported for dask/xgboost,'
msg += ' contribution are welcomed.'
@ -107,7 +95,7 @@ class RabitContext:
'''A context controling rabit initialization and finalization.'''
def __init__(self, args):
self.args = args
worker = distributed_get_worker()
worker = distributed.get_worker()
self.args.append(
('DMLC_TASK_ID=[xgboost.dask]:' + str(worker.address)).encode())
@ -146,10 +134,10 @@ def concat(value): # pylint: disable=too-many-return-statements
def _xgb_get_client(client):
'''Simple wrapper around testing None.'''
if not isinstance(client, (type(get_client()), type(None))):
if not isinstance(client, (type(distributed.get_client()), type(None))):
raise TypeError(
_expect([type(get_client()), type(None)], type(client)))
ret = get_client() if client is None else client
_expect([type(distributed.get_client()), type(None)], type(client)))
ret = distributed.get_client() if client is None else client
return ret
@ -217,7 +205,7 @@ class DaskDMatrix:
feature_names=None,
feature_types=None):
_assert_dask_support()
client: Client = _xgb_get_client(client)
client: distributed.Client = _xgb_get_client(client)
self.feature_names = feature_names
self.feature_types = feature_types
@ -313,10 +301,10 @@ class DaskDMatrix:
append_meta(ll_parts, 'label_lower_bound')
append_meta(lu_parts, 'label_upper_bound')
parts = list(map(delayed, zip(*parts)))
parts = list(map(dask.delayed, zip(*parts)))
parts = client.compute(parts)
await distributed_wait(parts) # async wait for parts to be computed
await distributed.wait(parts) # async wait for parts to be computed
for part in parts:
assert part.status == 'finished'
@ -354,7 +342,7 @@ class DaskDMatrix:
def _get_worker_parts_ordered(has_base_margin, worker_map, partition_order,
worker):
list_of_parts = worker_map[worker.address]
client = get_client()
client = distributed.get_client()
list_of_parts_value = client.gather(list_of_parts)
result = []
@ -378,7 +366,7 @@ def _get_worker_parts(worker_map, meta_names, worker):
# `_get_worker_parts` is launched inside worker. In dask side
# this should be equal to `worker._get_client`.
client = get_client()
client = distributed.get_client()
list_of_parts = client.gather(list_of_parts)
data = None
labels = None
@ -544,7 +532,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
def _create_device_quantile_dmatrix(feature_names, feature_types,
meta_names, missing, worker_map,
max_bin):
worker = distributed_get_worker()
worker = distributed.get_worker()
if worker.address not in set(worker_map.keys()):
msg = 'worker {address} has an empty DMatrix. ' \
'All workers associated with this DMatrix: {workers}'.format(
@ -584,7 +572,7 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
A DMatrix object.
'''
worker = distributed_get_worker()
worker = distributed.get_worker()
if worker.address not in set(worker_map.keys()):
msg = 'worker {address} has an empty DMatrix. ' \
'All workers associated with this DMatrix: {workers}'.format(
@ -630,9 +618,9 @@ def _dmatrix_from_worker_map(is_quantile, **kwargs):
return _create_dmatrix(**kwargs)
async def _get_rabit_args(worker_map, client: Client):
async def _get_rabit_args(worker_map, client):
'''Get rabit context arguments from data distribution in DaskDMatrix.'''
host = distributed_comm.get_address_host(client.scheduler.address)
host = distributed.comm.get_address_host(client.scheduler.address)
env = await client.run_on_scheduler(
_start_tracker, host.strip('/:'), len(worker_map))
rabit_args = [('%s=%s' % item).encode() for item in env.items()]
@ -648,7 +636,7 @@ async def _get_rabit_args(worker_map, client: Client):
async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
early_stopping_rounds=None, **kwargs):
_assert_dask_support()
client: Client = _xgb_get_client(client)
client: distributed.Client = _xgb_get_client(client)
if 'evals_result' in kwargs.keys():
raise ValueError(
'evals_result is not supported in dask interface.',
@ -662,7 +650,7 @@ async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
'''
LOGGER.info('Training on %s', str(worker_addr))
worker = distributed_get_worker()
worker = distributed.get_worker()
with RabitContext(rabit_args):
local_dtrain = _dmatrix_from_worker_map(**dtrain_ref)
local_evals = []
@ -774,7 +762,7 @@ async def _direct_predict_impl(client, data, predict_fn):
# pylint: disable=too-many-statements
async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwargs):
async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
if isinstance(model, Booster):
booster = model
elif isinstance(model, dict):
@ -786,7 +774,7 @@ async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwarg
type(data)))
def mapped_predict(partition, is_df):
worker = distributed_get_worker()
worker = distributed.get_worker()
booster.set_param({'nthread': worker.nthreads})
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
predt = booster.predict(m, validate_features=False, **kwargs)
@ -813,7 +801,7 @@ async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwarg
'''Perform prediction on each worker.'''
LOGGER.info('Predicting on %d', worker_id)
worker = distributed_get_worker()
worker = distributed.get_worker()
list_of_parts = _get_worker_parts_ordered(
has_margin, worker_map, partition_order, worker)
predictions = []
@ -832,14 +820,14 @@ async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwarg
validate_features=local_part.num_row() != 0,
**kwargs)
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
ret = ((delayed(predt), columns), order)
ret = ((dask.delayed(predt), columns), order)
predictions.append(ret)
return predictions
def dispatched_get_shape(worker_id):
'''Get shape of data in each worker.'''
LOGGER.info('Get shape on %d', worker_id)
worker = distributed_get_worker()
worker = distributed.get_worker()
list_of_parts = _get_worker_parts_ordered(
False,
worker_map,
@ -930,7 +918,7 @@ async def _inplace_predict_async(client, model, data,
raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
def mapped_predict(data, is_df):
worker = distributed_get_worker()
worker = distributed.get_worker()
booster.set_param({'nthread': worker.nthreads})
prediction = booster.inplace_predict(
data,
@ -1072,7 +1060,7 @@ class DaskScikitLearnBase(XGBModel):
return self.client.sync(_).__await__()
@property
def client(self) -> Client:
def client(self):
'''The dask client used in this model.'''
client = _xgb_get_client(self._client)
return client