Modin DF support (#6055)
* Modin DF support * mode change * tests were added, ci env was extended * mode change * Remove redundant installation of modin * Add a pytest skip marker for modin * Install Modin[ray] from PyPI * fix interfering * avoid extra conversion * delete cv test for modin * revert cv function Co-authored-by: ShvetsKS <kirill.shvets@intel.com> Co-authored-by: Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
3a990433f9
commit
c1ca872d1e
@ -151,6 +151,13 @@ def _is_pandas_df(data):
|
||||
return False
|
||||
return isinstance(data, pd.DataFrame)
|
||||
|
||||
def _is_modin_df(data):
|
||||
try:
|
||||
import modin.pandas as pd
|
||||
except ImportError:
|
||||
return False
|
||||
return isinstance(data, pd.DataFrame)
|
||||
|
||||
|
||||
_pandas_dtype_mapper = {
|
||||
'int8': 'int',
|
||||
@ -208,8 +215,8 @@ def _transform_pandas_df(data, feature_names=None, feature_types=None,
|
||||
'DataFrame for {meta} cannot have multiple columns'.format(
|
||||
meta=meta))
|
||||
|
||||
dtype = meta_type if meta_type else 'float'
|
||||
data = data.values.astype(dtype)
|
||||
dtype = meta_type if meta_type else np.float32
|
||||
data = np.ascontiguousarray(data.values, dtype=dtype)
|
||||
|
||||
return data, feature_names, feature_types
|
||||
|
||||
@ -228,6 +235,13 @@ def _is_pandas_series(data):
|
||||
return False
|
||||
return isinstance(data, pd.Series)
|
||||
|
||||
def _is_modin_series(data):
|
||||
try:
|
||||
import modin.pandas as pd
|
||||
except ImportError:
|
||||
return False
|
||||
return isinstance(data, pd.Series)
|
||||
|
||||
|
||||
def _from_pandas_series(data, missing, nthread, feature_types, feature_names):
|
||||
return _from_numpy_array(data.values.astype('float'), missing, nthread,
|
||||
@ -525,6 +539,12 @@ def dispatch_data_backend(data, missing, threads,
|
||||
_warn_unused_missing(data, missing)
|
||||
return _from_dt_df(data, missing, threads, feature_names,
|
||||
feature_types)
|
||||
if _is_modin_df(data):
|
||||
return _from_pandas_df(data, missing, threads,
|
||||
feature_names, feature_types)
|
||||
if _is_modin_series(data):
|
||||
return _from_pandas_series(data, missing, threads, feature_names,
|
||||
feature_types)
|
||||
if _has_array_protocol(data):
|
||||
pass
|
||||
raise TypeError('Not supported type for data.' + str(type(data)))
|
||||
@ -648,6 +668,15 @@ def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
|
||||
if _is_dt_df(data):
|
||||
_meta_from_dt(data, name, dtype, handle)
|
||||
return
|
||||
if _is_modin_df(data):
|
||||
data, _, _ = _transform_pandas_df(data, meta=name, meta_type=dtype)
|
||||
_meta_from_numpy(data, name, dtype, handle)
|
||||
return
|
||||
if _is_modin_series(data):
|
||||
data = data.values.astype('float')
|
||||
assert len(data.shape) == 1 or data.shape[1] == 0 or data.shape[1] == 1
|
||||
_meta_from_numpy(data, name, dtype, handle)
|
||||
return
|
||||
if _has_array_protocol(data):
|
||||
pass
|
||||
raise TypeError('Unsupported type for ' + name, str(type(data)))
|
||||
|
||||
@ -31,3 +31,4 @@ dependencies:
|
||||
- pip:
|
||||
- guzzle_sphinx_theme
|
||||
- datatable
|
||||
- modin[all]
|
||||
|
||||
@ -16,3 +16,4 @@ dependencies:
|
||||
- pip
|
||||
- pip:
|
||||
- cupy-cuda101
|
||||
- modin[all]
|
||||
|
||||
145
tests/python/test_with_modin.py
Normal file
145
tests/python/test_with_modin.py
Normal file
@ -0,0 +1,145 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import testing as tm
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
try:
|
||||
import modin.pandas as md
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(**tm.no_modin())
|
||||
|
||||
|
||||
dpath = 'demo/data/'
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
|
||||
class TestModin(unittest.TestCase):
|
||||
|
||||
def test_modin(self):
|
||||
|
||||
df = md.DataFrame([[1, 2., True], [2, 3., False]],
|
||||
columns=['a', 'b', 'c'])
|
||||
dm = xgb.DMatrix(df, label=md.Series([1, 2]))
|
||||
assert dm.feature_names == ['a', 'b', 'c']
|
||||
assert dm.feature_types == ['int', 'float', 'i']
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 3
|
||||
np.testing.assert_array_equal(dm.get_label(), np.array([1, 2]))
|
||||
|
||||
# overwrite feature_names and feature_types
|
||||
dm = xgb.DMatrix(df, label=md.Series([1, 2]),
|
||||
feature_names=['x', 'y', 'z'],
|
||||
feature_types=['q', 'q', 'q'])
|
||||
assert dm.feature_names == ['x', 'y', 'z']
|
||||
assert dm.feature_types == ['q', 'q', 'q']
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 3
|
||||
|
||||
# incorrect dtypes
|
||||
df = md.DataFrame([[1, 2., 'x'], [2, 3., 'y']],
|
||||
columns=['a', 'b', 'c'])
|
||||
self.assertRaises(ValueError, xgb.DMatrix, df)
|
||||
|
||||
# numeric columns
|
||||
df = md.DataFrame([[1, 2., True], [2, 3., False]])
|
||||
dm = xgb.DMatrix(df, label=md.Series([1, 2]))
|
||||
assert dm.feature_names == ['0', '1', '2']
|
||||
assert dm.feature_types == ['int', 'float', 'i']
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 3
|
||||
np.testing.assert_array_equal(dm.get_label(), np.array([1, 2]))
|
||||
|
||||
df = md.DataFrame([[1, 2., 1], [2, 3., 1]], columns=[4, 5, 6])
|
||||
dm = xgb.DMatrix(df, label=md.Series([1, 2]))
|
||||
assert dm.feature_names == ['4', '5', '6']
|
||||
assert dm.feature_types == ['int', 'float', 'int']
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 3
|
||||
|
||||
df = md.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
|
||||
dummies = md.get_dummies(df)
|
||||
# B A_X A_Y A_Z
|
||||
# 0 1 1 0 0
|
||||
# 1 2 0 1 0
|
||||
# 2 3 0 0 1
|
||||
result, _, _ = xgb.data._transform_pandas_df(dummies)
|
||||
exp = np.array([[1., 1., 0., 0.],
|
||||
[2., 0., 1., 0.],
|
||||
[3., 0., 0., 1.]])
|
||||
np.testing.assert_array_equal(result, exp)
|
||||
dm = xgb.DMatrix(dummies)
|
||||
assert dm.feature_names == ['B', 'A_X', 'A_Y', 'A_Z']
|
||||
assert dm.feature_types == ['int', 'int', 'int', 'int']
|
||||
assert dm.num_row() == 3
|
||||
assert dm.num_col() == 4
|
||||
|
||||
df = md.DataFrame({'A=1': [1, 2, 3], 'A=2': [4, 5, 6]})
|
||||
dm = xgb.DMatrix(df)
|
||||
assert dm.feature_names == ['A=1', 'A=2']
|
||||
assert dm.feature_types == ['int', 'int']
|
||||
assert dm.num_row() == 3
|
||||
assert dm.num_col() == 2
|
||||
|
||||
df_int = md.DataFrame([[1, 1.1], [2, 2.2]], columns=[9, 10])
|
||||
dm_int = xgb.DMatrix(df_int)
|
||||
df_range = md.DataFrame([[1, 1.1], [2, 2.2]], columns=range(9, 11, 1))
|
||||
dm_range = xgb.DMatrix(df_range)
|
||||
assert dm_int.feature_names == ['9', '10'] # assert not "9 "
|
||||
assert dm_int.feature_names == dm_range.feature_names
|
||||
|
||||
# test MultiIndex as columns
|
||||
df = md.DataFrame(
|
||||
[
|
||||
(1, 2, 3, 4, 5, 6),
|
||||
(6, 5, 4, 3, 2, 1)
|
||||
],
|
||||
columns=md.MultiIndex.from_tuples((
|
||||
('a', 1), ('a', 2), ('a', 3),
|
||||
('b', 1), ('b', 2), ('b', 3),
|
||||
))
|
||||
)
|
||||
dm = xgb.DMatrix(df)
|
||||
assert dm.feature_names == ['a 1', 'a 2', 'a 3', 'b 1', 'b 2', 'b 3']
|
||||
assert dm.feature_types == ['int', 'int', 'int', 'int', 'int', 'int']
|
||||
assert dm.num_row() == 2
|
||||
assert dm.num_col() == 6
|
||||
|
||||
def test_modin_label(self):
|
||||
# label must be a single column
|
||||
df = md.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]})
|
||||
self.assertRaises(ValueError, xgb.data._transform_pandas_df, df,
|
||||
None, None, 'label', 'float')
|
||||
|
||||
# label must be supported dtype
|
||||
df = md.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)})
|
||||
self.assertRaises(ValueError, xgb.data._transform_pandas_df, df,
|
||||
None, None, 'label', 'float')
|
||||
|
||||
df = md.DataFrame({'A': np.array([1, 2, 3], dtype=int)})
|
||||
result, _, _ = xgb.data._transform_pandas_df(df, None, None,
|
||||
'label', 'float')
|
||||
np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]],
|
||||
dtype=float))
|
||||
dm = xgb.DMatrix(np.random.randn(3, 2), label=df)
|
||||
assert dm.num_row() == 3
|
||||
assert dm.num_col() == 2
|
||||
|
||||
def test_modin_weight(self):
|
||||
kRows = 32
|
||||
kCols = 8
|
||||
|
||||
X = np.random.randn(kRows, kCols)
|
||||
y = np.random.randn(kRows)
|
||||
w = np.random.randn(kRows).astype(np.float32)
|
||||
w_pd = md.DataFrame(w)
|
||||
data = xgb.DMatrix(X, y, w_pd)
|
||||
|
||||
assert data.num_row() == kRows
|
||||
assert data.num_col() == kCols
|
||||
|
||||
np.testing.assert_array_equal(data.get_weight(), w)
|
||||
@ -37,6 +37,15 @@ def no_pandas():
|
||||
'reason': 'Pandas is not installed.'}
|
||||
|
||||
|
||||
def no_modin():
|
||||
reason = 'Modin is not installed.'
|
||||
try:
|
||||
import modin.pandas as _ # noqa
|
||||
return {'condition': False, 'reason': reason}
|
||||
except ImportError:
|
||||
return {'condition': True, 'reason': reason}
|
||||
|
||||
|
||||
def no_dt():
|
||||
import importlib.util
|
||||
spec = importlib.util.find_spec('datatable')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user