Define lazy isinstance for Python compat. (#5364)
* Avoid importing datatable. * Fix #5363.
This commit is contained in:
parent
0fd455e162
commit
a461a9a90a
@ -79,6 +79,14 @@ else:
|
|||||||
# END NUMPY PATHLIB ATTRIBUTION
|
# END NUMPY PATHLIB ATTRIBUTION
|
||||||
###############################################################################
|
###############################################################################
|
||||||
|
|
||||||
|
|
||||||
|
def lazy_isinstance(instance, module, name):
|
||||||
|
'''Use string representation to identify a type.'''
|
||||||
|
module = type(instance).__module__ == module
|
||||||
|
name = type(instance).__name__ == name
|
||||||
|
return module and name
|
||||||
|
|
||||||
|
|
||||||
# pandas
|
# pandas
|
||||||
try:
|
try:
|
||||||
from pandas import DataFrame, Series
|
from pandas import DataFrame, Series
|
||||||
@ -95,27 +103,6 @@ except ImportError:
|
|||||||
pandas_concat = None
|
pandas_concat = None
|
||||||
PANDAS_INSTALLED = False
|
PANDAS_INSTALLED = False
|
||||||
|
|
||||||
# dt
|
|
||||||
try:
|
|
||||||
# Workaround for #4473, compatibility with dask
|
|
||||||
if sys.__stdin__ is not None and sys.__stdin__.closed:
|
|
||||||
sys.__stdin__ = None
|
|
||||||
import datatable
|
|
||||||
|
|
||||||
if hasattr(datatable, "Frame"):
|
|
||||||
DataTable = datatable.Frame
|
|
||||||
else:
|
|
||||||
DataTable = datatable.DataTable
|
|
||||||
DT_INSTALLED = True
|
|
||||||
except ImportError:
|
|
||||||
|
|
||||||
# pylint: disable=too-few-public-methods
|
|
||||||
class DataTable(object):
|
|
||||||
""" dummy for datatable.DataTable """
|
|
||||||
|
|
||||||
DT_INSTALLED = False
|
|
||||||
|
|
||||||
|
|
||||||
# cudf
|
# cudf
|
||||||
try:
|
try:
|
||||||
from cudf import DataFrame as CUDF_DataFrame
|
from cudf import DataFrame as CUDF_DataFrame
|
||||||
|
|||||||
@ -19,9 +19,9 @@ import scipy.sparse
|
|||||||
|
|
||||||
from .compat import (
|
from .compat import (
|
||||||
STRING_TYPES, DataFrame, MultiIndex, Int64Index, py_str,
|
STRING_TYPES, DataFrame, MultiIndex, Int64Index, py_str,
|
||||||
PANDAS_INSTALLED, DataTable,
|
PANDAS_INSTALLED, CUDF_INSTALLED,
|
||||||
CUDF_INSTALLED, CUDF_DataFrame, CUDF_Series, CUDF_MultiIndex,
|
CUDF_DataFrame, CUDF_Series, CUDF_MultiIndex,
|
||||||
os_fspath, os_PathLike)
|
os_fspath, os_PathLike, lazy_isinstance)
|
||||||
from .libpath import find_lib_path
|
from .libpath import find_lib_path
|
||||||
|
|
||||||
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
|
# c_bst_ulong corresponds to bst_ulong defined in xgboost/c_api.h
|
||||||
@ -319,7 +319,8 @@ DT_TYPE_MAPPER2 = {'bool': 'i', 'int': 'int', 'real': 'float'}
|
|||||||
def _maybe_dt_data(data, feature_names, feature_types,
|
def _maybe_dt_data(data, feature_names, feature_types,
|
||||||
meta=None, meta_type=None):
|
meta=None, meta_type=None):
|
||||||
"""Validate feature names and types if data table"""
|
"""Validate feature names and types if data table"""
|
||||||
if not isinstance(data, DataTable):
|
if (not lazy_isinstance(data, 'datatable', 'Frame') and
|
||||||
|
not lazy_isinstance(data, 'datatable', 'DataTable')):
|
||||||
return data, feature_names, feature_types
|
return data, feature_names, feature_types
|
||||||
|
|
||||||
if meta and data.shape[1] > 1:
|
if meta and data.shape[1] > 1:
|
||||||
@ -470,7 +471,7 @@ class DMatrix(object):
|
|||||||
self._init_from_csc(data)
|
self._init_from_csc(data)
|
||||||
elif isinstance(data, np.ndarray):
|
elif isinstance(data, np.ndarray):
|
||||||
self._init_from_npy2d(data, missing, nthread)
|
self._init_from_npy2d(data, missing, nthread)
|
||||||
elif isinstance(data, DataTable):
|
elif lazy_isinstance(data, 'datatable', 'Frame'):
|
||||||
self._init_from_dt(data, nthread)
|
self._init_from_dt(data, nthread)
|
||||||
elif hasattr(data, "__cuda_array_interface__"):
|
elif hasattr(data, "__cuda_array_interface__"):
|
||||||
self._init_from_array_interface(data, missing, nthread)
|
self._init_from_array_interface(data, missing, nthread)
|
||||||
|
|||||||
@ -36,6 +36,11 @@ def captured_output():
|
|||||||
|
|
||||||
|
|
||||||
class TestBasic(unittest.TestCase):
|
class TestBasic(unittest.TestCase):
|
||||||
|
def test_compat(self):
|
||||||
|
from xgboost.compat import lazy_isinstance
|
||||||
|
a = np.array([1, 2, 3])
|
||||||
|
assert lazy_isinstance(a, 'numpy', 'ndarray')
|
||||||
|
assert not lazy_isinstance(a, 'numpy', 'dataframe')
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED, DT_INSTALLED
|
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED
|
||||||
from xgboost.compat import CUDF_INSTALLED, DASK_INSTALLED
|
from xgboost.compat import CUDF_INSTALLED, DASK_INSTALLED
|
||||||
|
|
||||||
|
|
||||||
@ -19,7 +19,9 @@ def no_pandas():
|
|||||||
|
|
||||||
|
|
||||||
def no_dt():
|
def no_dt():
|
||||||
return {'condition': not DT_INSTALLED,
|
import importlib.util
|
||||||
|
spec = importlib.util.find_spec('datatable')
|
||||||
|
return {'condition': spec is None,
|
||||||
'reason': 'Datatable is not installed.'}
|
'reason': 'Datatable is not installed.'}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user