diff --git a/python-package/xgboost/compat.py b/python-package/xgboost/compat.py
index 30a51ec68..edde7470c 100644
--- a/python-package/xgboost/compat.py
+++ b/python-package/xgboost/compat.py
@@ -1,11 +1,10 @@
 # coding: utf-8
 # pylint: disable= invalid-name, unused-import
 """For compatibility and optional dependencies."""
-import abc
-import os
 import sys
-from pathlib import PurePath
-
+import types
+import importlib.util
+import logging
 import numpy as np
 
 assert (sys.version_info[0] == 3), 'Python 2 is no longer supported.'
@@ -120,3 +119,60 @@ except ImportError:
     sparse = False
     scipy_sparse = False
     SCIPY_INSTALLED = False
+
+
+# Modified from tensorflow with added caching.  There's a `LazyLoader` in
+# `importlib.util`, except it's unclear from its documentation how to use it.
+# This one is easy to understand and works out of the box.
+
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this
+# file except in compliance with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under
+# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+class LazyLoader(types.ModuleType):
+    """Lazily import a module, mainly to avoid pulling in large dependencies.
+    """
+
+    def __init__(self, local_name, parent_module_globals, name, warning=None):
+        self._local_name = local_name
+        self._parent_module_globals = parent_module_globals
+        self._warning = warning
+        self.module = None
+
+        super().__init__(name)
+
+    def _load(self):
+        """Load the module and insert it into the parent's globals."""
+        # Import the target module and insert it into the parent's namespace.
+        module = importlib.import_module(self.__name__)
+        self._parent_module_globals[self._local_name] = module
+
+        # Emit a warning if one was specified.
+        if self._warning:
+            logging.warning(self._warning)
+            # Make sure to only warn once.
+            self._warning = None
+
+        # Update this object's dict so that if someone keeps a reference to the
+        # LazyLoader, lookups are efficient (__getattr__ is only called on lookups
+        # that fail).
+        self.__dict__.update(module.__dict__)
+
+        return module
+
+    def __getattr__(self, item):
+        if not self.module:
+            self.module = self._load()
+        return getattr(self.module, item)
+
+    def __dir__(self):
+        if not self.module:
+            self.module = self._load()
+        return dir(self.module)
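
A minimal usage sketch of the new `LazyLoader` (illustrative, not part of the
patch; `scipy.sparse` is an arbitrary stand-in for any heavy optional
dependency):

    from xgboost.compat import LazyLoader

    # Binding is cheap: nothing is imported at this point.
    sparse = LazyLoader('sparse', globals(), 'scipy.sparse')

    # The first attribute access triggers _load(): it imports scipy.sparse,
    # caches the module on the loader, and rebinds 'sparse' in this module's
    # globals to the real module, so later lookups skip __getattr__ entirely.
    matrix = sparse.csr_matrix((3, 3))
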
diff --git a/python-package/xgboost/dask.py b/python-package/xgboost/dask.py
index ed3aaae5e..7ae24f9d3 100644
--- a/python-package/xgboost/dask.py
+++ b/python-package/xgboost/dask.py
@@ -23,7 +23,7 @@
 import numpy
 
 from . import rabit
-from .compat import DASK_INSTALLED
+from .compat import LazyLoader
 from .compat import sparse, scipy_sparse
 from .compat import PANDAS_INSTALLED, DataFrame, Series, pandas_concat
 from .compat import CUDF_concat
@@ -35,23 +35,11 @@
 from .tracker import RabitTracker
 from .sklearn import XGBModel, XGBRegressorBase, XGBClassifierBase
 from .sklearn import xgboost_model_doc
 
-try:
-    from dask.distributed import Client, get_client
-    from dask.distributed import comm as distributed_comm
-    from dask.distributed import wait as distributed_wait
-    from dask.distributed import get_worker as distributed_get_worker
-    from dask import dataframe as dd
-    from dask import array as da
-    from dask import delayed
-except ImportError:
-    Client = None
-    get_client = None
-    distributed_comm = None
-    distributed_wait = None
-    distributed_get_worker = None
-    dd = None
-    da = None
-    delayed = None
+
+dd = LazyLoader('dd', globals(), 'dask.dataframe')
+da = LazyLoader('da', globals(), 'dask.array')
+dask = LazyLoader('dask', globals(), 'dask')
+distributed = LazyLoader('distributed', globals(), 'dask.distributed')
 # Current status is considered as initial support, many features are
 # not properly supported yet.
@@ -91,12 +79,12 @@ def _start_tracker(host, n_workers):
 
 
 def _assert_dask_support():
-    if not DASK_INSTALLED:
+    try:
+        import dask  # pylint: disable=W0621,W0611
+    except ImportError as e:
         raise ImportError(
-            'Dask needs to be installed in order to use this module')
-    if not distributed_wait:
-        raise ImportError(
-            'distributed needs to be installed in order to use this module.')
+            'Dask needs to be installed in order to use this module') from e
+
     if platform.system() == 'Windows':
         msg = 'Windows is not officially supported for dask/xgboost,'
         msg += ' contribution are welcomed.'
@@ -107,7 +95,7 @@ class RabitContext:
     '''A context controling rabit initialization and finalization.'''
 
     def __init__(self, args):
         self.args = args
-        worker = distributed_get_worker()
+        worker = distributed.get_worker()
         self.args.append(
             ('DMLC_TASK_ID=[xgboost.dask]:' + str(worker.address)).encode())
@@ -146,10 +134,10 @@ def concat(value):  # pylint: disable=too-many-return-statements
 
 
 def _xgb_get_client(client):
     '''Simple wrapper around testing None.'''
-    if not isinstance(client, (type(get_client()), type(None))):
+    if not isinstance(client, (type(distributed.get_client()), type(None))):
         raise TypeError(
-            _expect([type(get_client()), type(None)], type(client)))
-    ret = get_client() if client is None else client
+            _expect([type(distributed.get_client()), type(None)], type(client)))
+    ret = distributed.get_client() if client is None else client
     return ret
 
@@ -217,7 +205,7 @@ class DaskDMatrix:
                  feature_names=None,
                  feature_types=None):
         _assert_dask_support()
-        client: Client = _xgb_get_client(client)
+        client: distributed.Client = _xgb_get_client(client)
 
         self.feature_names = feature_names
         self.feature_types = feature_types
@@ -313,10 +301,10 @@ class DaskDMatrix:
         append_meta(ll_parts, 'label_lower_bound')
         append_meta(lu_parts, 'label_upper_bound')
 
-        parts = list(map(delayed, zip(*parts)))
+        parts = list(map(dask.delayed, zip(*parts)))
         parts = client.compute(parts)
-        await distributed_wait(parts)  # async wait for parts to be computed
+        await distributed.wait(parts)  # async wait for parts to be computed
 
         for part in parts:
             assert part.status == 'finished'
@@ -354,7 +342,7 @@ class DaskDMatrix:
 def _get_worker_parts_ordered(has_base_margin, worker_map, partition_order,
                               worker):
     list_of_parts = worker_map[worker.address]
-    client = get_client()
+    client = distributed.get_client()
     list_of_parts_value = client.gather(list_of_parts)
 
     result = []
@@ -378,7 +366,7 @@ def _get_worker_parts(worker_map, meta_names, worker):
 
     # `_get_worker_parts` is launched inside worker.  In dask side
     # this should be equal to `worker._get_client`.
-    client = get_client()
+    client = distributed.get_client()
     list_of_parts = client.gather(list_of_parts)
     data = None
     labels = None
@@ -544,7 +532,7 @@ class DaskDeviceQuantileDMatrix(DaskDMatrix):
 def _create_device_quantile_dmatrix(feature_names, feature_types,
                                     meta_names, missing, worker_map,
                                     max_bin):
-    worker = distributed_get_worker()
+    worker = distributed.get_worker()
     if worker.address not in set(worker_map.keys()):
         msg = 'worker {address} has an empty DMatrix. ' \
             'All workers associated with this DMatrix: {workers}'.format(
@@ -584,7 +572,7 @@ def _create_dmatrix(feature_names, feature_types, meta_names, missing,
     A DMatrix object.
 
     '''
-    worker = distributed_get_worker()
+    worker = distributed.get_worker()
     if worker.address not in set(worker_map.keys()):
         msg = 'worker {address} has an empty DMatrix. ' \
            'All workers associated with this DMatrix: {workers}'.format(
@@ -630,9 +618,9 @@ def _dmatrix_from_worker_map(is_quantile, **kwargs):
     return _create_dmatrix(**kwargs)
 
 
-async def _get_rabit_args(worker_map, client: Client):
+async def _get_rabit_args(worker_map, client):
     '''Get rabit context arguments from data distribution in DaskDMatrix.'''
-    host = distributed_comm.get_address_host(client.scheduler.address)
+    host = distributed.comm.get_address_host(client.scheduler.address)
     env = await client.run_on_scheduler(
         _start_tracker, host.strip('/:'), len(worker_map))
     rabit_args = [('%s=%s' % item).encode() for item in env.items()]
@@ -648,7 +636,7 @@
 async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
                        early_stopping_rounds=None, **kwargs):
     _assert_dask_support()
-    client: Client = _xgb_get_client(client)
+    client: distributed.Client = _xgb_get_client(client)
     if 'evals_result' in kwargs.keys():
         raise ValueError(
             'evals_result is not supported in dask interface.',
@@ -662,7 +650,7 @@ async def _train_async(client, params, dtrain: DaskDMatrix, *args, evals=(),
 
         '''
         LOGGER.info('Training on %s', str(worker_addr))
-        worker = distributed_get_worker()
+        worker = distributed.get_worker()
         with RabitContext(rabit_args):
             local_dtrain = _dmatrix_from_worker_map(**dtrain_ref)
             local_evals = []
@@ -774,7 +762,7 @@ async def _direct_predict_impl(client, data, predict_fn):
 
 
 # pylint: disable=too-many-statements
-async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwargs):
+async def _predict_async(client, model, data, missing=numpy.nan, **kwargs):
     if isinstance(model, Booster):
         booster = model
     elif isinstance(model, dict):
@@ -786,7 +774,7 @@ async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwarg
             type(data)))
 
     def mapped_predict(partition, is_df):
-        worker = distributed_get_worker()
+        worker = distributed.get_worker()
         booster.set_param({'nthread': worker.nthreads})
         m = DMatrix(partition, missing=missing, nthread=worker.nthreads)
         predt = booster.predict(m, validate_features=False, **kwargs)
@@ -813,7 +801,7 @@ async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwarg
         '''Perform prediction on each worker.'''
         LOGGER.info('Predicting on %d', worker_id)
-        worker = distributed_get_worker()
+        worker = distributed.get_worker()
         list_of_parts = _get_worker_parts_ordered(
             has_margin, worker_map, partition_order,
             worker)
         predictions = []
@@ -832,14 +820,14 @@ async def _predict_async(client: Client, model, data, missing=numpy.nan, **kwarg
                 validate_features=local_part.num_row() != 0,
                 **kwargs)
             columns = 1 if len(predt.shape) == 1 else predt.shape[1]
-            ret = ((delayed(predt), columns), order)
+            ret = ((dask.delayed(predt), columns), order)
             predictions.append(ret)
         return predictions
 
     def dispatched_get_shape(worker_id):
         '''Get shape of data in each worker.'''
         LOGGER.info('Get shape on %d', worker_id)
-        worker = distributed_get_worker()
+        worker = distributed.get_worker()
         list_of_parts = _get_worker_parts_ordered(
             False,
             worker_map,
@@ -930,7 +918,7 @@ async def _inplace_predict_async(client, model, data,
         raise TypeError(_expect([da.Array, dd.DataFrame], type(data)))
 
     def mapped_predict(data, is_df):
-        worker = distributed_get_worker()
+        worker = distributed.get_worker()
         booster.set_param({'nthread': worker.nthreads})
         prediction = booster.inplace_predict(
             data,
@@ -1072,7 +1060,7 @@ class DaskScikitLearnBase(XGBModel):
         return self.client.sync(_).__await__()
 
     @property
-    def client(self) -> Client:
+    def client(self):
         '''The dask client used in this model.'''
         client = _xgb_get_client(self._client)
         return client
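
A minimal end-to-end sketch of the lazy imports in action (illustrative, not
part of the patch; assumes a local cluster, and the data and parameters are
placeholders):

    # Importing the xgboost dask module no longer imports dask itself; the
    # `dd`, `da`, `dask` and `distributed` handles resolve on first use, and
    # _assert_dask_support() raises a clear ImportError if dask is absent.
    from xgboost import dask as dxgb

    from distributed import Client
    from dask import array as da

    if __name__ == '__main__':
        client = Client()  # dask/distributed are actually imported by now
        X = da.random.random((1000, 10), chunks=(100, 10))
        y = da.random.random(1000, chunks=100)
        dtrain = dxgb.DaskDMatrix(client, X, y)
        output = dxgb.train(client, {'tree_method': 'hist'},
                            dtrain, num_boost_round=10)
        prediction = dxgb.predict(client, output['booster'], dtrain)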