Lazy import dask libraries. (#6309)
* Lazy import dask libraries. * Lint && fix. * Use short name.
This commit is contained in:
parent
dfac5f89e9
commit
74ea82209b
@ -1,11 +1,10 @@
|
||||
# coding: utf-8
|
||||
# pylint: disable= invalid-name, unused-import
|
||||
"""For compatibility and optional dependencies."""
|
||||
import abc
|
||||
import os
|
||||
import sys
|
||||
from pathlib import PurePath
|
||||
|
||||
import types
|
||||
import importlib.util
|
||||
import logging
|
||||
import numpy as np
|
||||
|
||||
assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'
|
||||
@ -120,3 +119,60 @@ except ImportError:
|
||||
sparse = False
|
||||
scipy_sparse = False
|
||||
SCIPY_INSTALLED = False
|
||||
|
||||
|
||||
# Modified from tensorflow with added caching. There's a `LazyLoader` in
|
||||
# `importlib.utils`, except it's unclear from its document on how to use it. This one
|
||||
# seems to be easy to understand and works out of box.
|
||||
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this
|
||||
# file except in compliance with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software distributed under
|
||||
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the specific language governing
|
||||
# permissions and limitations under the License.
|
||||
class LazyLoader(types.ModuleType):
|
||||
"""Lazily import a module, mainly to avoid pulling in large dependencies.
|
||||
"""
|
||||
|
||||
def __init__(self, local_name, parent_module_globals, name, warning=None):
|
||||
self._local_name = local_name
|
||||
self._parent_module_globals = parent_module_globals
|
||||
self._warning = warning
|
||||
self.module = None
|
||||
|
||||
super().__init__(name)
|
||||
|
||||
def _load(self):
|
||||
"""Load the module and insert it into the parent's globals."""
|
||||
# Import the target module and insert it into the parent's namespace
|
||||
module = importlib.import_module(self.__name__)
|
||||
self._parent_module_globals[self._local_name] = module
|
||||
|
||||
# Emit a warning if one was specified
|
||||
if self._warning:
|
||||
logging.warning(self._warning)
|
||||
# Make sure to only warn once.
|
||||
self._warning = None
|
||||
|
||||
# Update this object's dict so that if someone keeps a reference to the
|
||||
# LazyLoader, lookups are efficient (__getattr__ is only called on lookups
|
||||
# that fail).
|
||||
self.__dict__.update(module.__dict__)
|
||||
|
||||
return module
|
||||
|
||||
def __getattr__(self, item):
|
||||
if not self.module:
|
||||
self.module = self._load()
|
||||
return getattr(self.module, item)
|
||||
|
||||
def __dir__(self):
|
||||
if not self.module:
|
||||
self.module = self._load()
|
||||
return dir(self.module)
|
||||
|
||||
@ -23,7 +23,7 @@ import numpy
|
||||
|
||||
from . import rabit
|
||||
|
||||
from .compat import DASK_INSTALLED
|
||||
from .compat import LazyLoader
|
||||
from .compat import sparse, scipy_sparse
|
||||
from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
|
||||
from .compat import CUDF_concat
|
||||
@ -35,23 +35,11 @@ from .tracker import RabitTracker
|
||||
from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
|
||||
from .sklearn import xgboost_model_doc
|
||||
|
||||
try:
|
||||
from dask.distributed import Client, get_client
|
||||
from dask.distributed import comm as distributed_comm
|
||||
from dask.distributed import wait as distributed_wait
|
||||
from dask.distributed import get_worker as distributed_get_worker
|
||||
from dask import dataframe as dd
|
||||
from dask import array as da
|
||||
from dask import delayed
|
||||
except ImportError:
|
||||
Client = None
|
||||
get_client = None
|
||||
distributed_comm = None
|
||||
distributed_wait = None
|
||||
distributed_get_worker = None
|
||||
dd = None
|
||||
da = None
|
||||
delayed = None
|
||||
|
||||
dd = LazyLoader('dd', globals(), 'dask.dataframe')
|
||||
da = LazyLoader('da', globals(), 'dask.array')
|
||||
dask = LazyLoader('dask', globals(), 'dask')
|
||||
distributed = LazyLoader('distributed', globals(), 'dask.distributed')
|
||||
|
||||
# Current status is considered as initial support, many features are
|
||||
# not properly supported yet.
|
||||
@ -91,12 +79,12 @@ def _start_tracker(host, n_workers):
|
||||
|
||||
|
||||
def _assert_dask_support():
|
||||
if not DASK_INSTALLED:
|
||||
try:
|
||||
import dask # pylint: disable=W0621,W0611
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
'Dask needs to be installed in order to use this module')
|
||||
if not distributed_wait:
|
||||
raise ImportError(
|
||||
'distributed needs to be installed in order to use this module.')
|
||||
'Dask needs to be installed in order to use this module') from e
|
||||
|
||||
if platform.system() == 'Windows':
|
||||
msg = 'Windows is not officially supported for dask/xgboost,'
|
||||
msg += ' contribution are welcomed.'
|
||||
@ -107,7 +95,7 @@ class RabitContext:
|
||||
'''A context controling rabit initialization and finalization.'''
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
worker = distributed_get_worker()
|
||||
worker = distributed.get_worker()
|
||||
self.args.append(
|
||||
('DMLC_TASK_ID=[xgboost.dask]:' + str(worker.address)).encode())
|
||||
|
||||
@ -146,10 +134,10 @@ def concat(value): # pylint: disable=too-many-return-statements
|
||||
|
||||
def _xgb_get_client(client):
|
||||
'''Simple wrapper around testing None.'''
|
||||
if not isinstance(client, (type(get_client()), type(None))):
|
||||
if not isinstance(client, (type(distributed.get_client()), type(None))):
|
||||
raise TypeError(
|
||||
_expect([type(get_client()), type(None)], type(client)))
|
||||
ret = get_client() if client is None else client
|
||||
_expect([type(distributed.get_client()), type(None)], type(client)))
|
||||
ret = distributed.get_client() if client is None else client
|
||||
return ret
|
||||
|
||||
|
||||
@ -217,7 +205,7 @@ class DaskDMatrix:
|
||||
feature_names=None,
|
||||
feature_types=None):
|
||||
_assert_dask_support()
|
||||
client: Client = _xgb_get_client(client)
|
||||
client: distributed.Client = _xgb_get_client(client)
|
||||
|
||||
self.feature_names = feature_names
|
||||
self.feature_types = feature_types
|
||||
@ -313,10 +301,10 @@ class DaskDMatrix:
|
||||
append_meta(ll_parts, 'label_lower_bound')
|
||||
append_meta(lu_parts, 'label_upper_bound')
|
||||
|
||||
parts = list(map(delayed, zip(*parts)))
|
||||
parts = list(map(dask.delayed, zip(*parts)))
|
||||
|
||||
parts = client.compute(parts)
|
||||
await distributed_wait(parts) # async wait for parts to be computed
|
||||
await distributed.wait(parts) # async wait for parts to be computed
|
||||
|
||||
for part in parts:
|
||||
assert part.status == 'finished'
|
||||
@ -354,7 +342,7 @@ class DaskDMatrix:
|
||||
def _get_worker_parts_ordered(has_base_margin, worker_map, partition_order,
|
||||
worker):
|
||||
list_of_parts = worker_map[worker.address]
|
||||
client = get_client()
|
||||
client = distributed.get_client()
|
||||
list_of_parts_value = client.gather(list_of_parts)
|
||||
|
||||
result = []
|
||||
@ -378,7 +366,7 @@ def _get_worker_parts(worker_map, meta_names, worker):
|
||||
|
||||
# `_get_worker_parts` is launched inside worker. In dask side
|
||||
# this should be equal to `worker._get_client`.
|
||||
client = get_client()
|
||||
client = distributed.get_client()
|
||||
list_of_parts = client.gather(list_of_parts)
|
||||
data = None
|
||||
labels = None
|
||||
@ -544,7 +532,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
|
||||
def _create_device_quantile_dmatrix(feature_names, feature_types,
|
||||
meta_names, missing, worker_map,
|
||||
max_bin):
|
||||
worker = distributed_get_worker()
|
||||
worker = distributed.get_worker()
|
||||
if worker.address not in set(worker_map.keys()):
|
||||
msg = 'worker {address} has an empty DMatrix. ' \
|
||||
'All workers associated with this DMatrix: {workers}'.format(
|
||||
@ -584,7 +572,7 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
|
||||
A DMatrix object.
|
||||
|
||||
'''
|
||||
worker = distributed_get_worker()
|
||||
worker = distributed.get_worker()
|
||||
if worker.address not in set(worker_map.keys()):
|
||||
msg = 'worker {address} has an empty DMatrix. ' \
|
||||
'All workers associated with this DMatrix: {workers}'.format(
|
||||
@ -630,9 +618,9 @@ def _dmatrix_from_worker_map(is_quantile, **kwargs):
|
||||
return _create_dmatrix(**kwargs)
|
||||
|
||||
|
||||
async def _get_rabit_args(worker_map, client: Client):
|
||||
async def _get_rabit_args(worker_map, client):
|
||||
'''Get rabit context arguments from data distribution in DaskDMatrix.'''
|
||||
host = distributed_comm.get_address_host(client.scheduler.address)
|
||||
host = distributed.comm.get_address_host(client.scheduler.address)
|
||||
env = await client.run_on_scheduler(
|
||||
_start_tracker, host.strip('/:'), len(worker_map))
|
||||
rabit_args = [('%s=%s' % item).encode() for item in env.items()]
|
||||
@ -648,7 +636,7 @@ async def _get_rabit_args(worker_map, client: Client):
|
||||
async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
|
||||
early_stopping_rounds=None, **kwargs):
|
||||
_assert_dask_support()
|
||||
client: Client = _xgb_get_client(client)
|
||||
client: distributed.Client = _xgb_get_client(client)
|
||||
if 'evals_result' in kwargs.keys():
|
||||
raise ValueError(
|
||||
'evals_result is not supported in dask interface.',
|
||||
@ -662,7 +650,7 @@ async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
|
||||
|
||||
'''
|
||||
LOGGER.info('Training on %s', str(worker_addr))
|
||||
worker = distributed_get_worker()
|
||||
worker = distributed.get_worker()
|
||||
with RabitContext(rabit_args):
|
||||
local_dtrain = _dmatrix_from_worker_map(**dtrain_ref)
|
||||
local_evals = []
|
||||
@ -774,7 +762,7 @@ async def _direct_predict_impl(client, data, predict_fn):
|
||||
|
||||
|
||||
# pylint: disable=too-many-statements
|
||||
async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwargs):
|
||||
async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
|
||||
if isinstance(model, Booster):
|
||||
booster = model
|
||||
elif isinstance(model, dict):
|
||||
@ -786,7 +774,7 @@ async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwarg
|
||||
type(data)))
|
||||
|
||||
def mapped_predict(partition, is_df):
|
||||
worker = distributed_get_worker()
|
||||
worker = distributed.get_worker()
|
||||
booster.set_param({'nthread': worker.nthreads})
|
||||
m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
|
||||
predt = booster.predict(m, validate_features=False, **kwargs)
|
||||
@ -813,7 +801,7 @@ async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwarg
|
||||
'''Perform prediction on each worker.'''
|
||||
LOGGER.info('Predicting on %d', worker_id)
|
||||
|
||||
worker = distributed_get_worker()
|
||||
worker = distributed.get_worker()
|
||||
list_of_parts = _get_worker_parts_ordered(
|
||||
has_margin, worker_map, partition_order, worker)
|
||||
predictions = []
|
||||
@ -832,14 +820,14 @@ async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwarg
|
||||
validate_features=local_part.num_row() != 0,
|
||||
**kwargs)
|
||||
columns = 1 if len(predt.shape) == 1 else predt.shape[1]
|
||||
ret = ((delayed(predt), columns), order)
|
||||
ret = ((dask.delayed(predt), columns), order)
|
||||
predictions.append(ret)
|
||||
return predictions
|
||||
|
||||
def dispatched_get_shape(worker_id):
|
||||
'''Get shape of data in each worker.'''
|
||||
LOGGER.info('Get shape on %d', worker_id)
|
||||
worker = distributed_get_worker()
|
||||
worker = distributed.get_worker()
|
||||
list_of_parts = _get_worker_parts_ordered(
|
||||
False,
|
||||
worker_map,
|
||||
@ -930,7 +918,7 @@ async def _inplace_predict_async(client, model, data,
|
||||
raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
|
||||
|
||||
def mapped_predict(data, is_df):
|
||||
worker = distributed_get_worker()
|
||||
worker = distributed.get_worker()
|
||||
booster.set_param({'nthread': worker.nthreads})
|
||||
prediction = booster.inplace_predict(
|
||||
data,
|
||||
@ -1072,7 +1060,7 @@ class DaskScikitLearnBase(XGBModel):
|
||||
return self.client.sync(_).__await__()
|
||||
|
||||
@property
|
||||
def client(self) -> Client:
|
||||
def client(self):
|
||||
'''The dask client used in this model.'''
|
||||
client = _xgb_get_client(self._client)
|
||||
return client
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user