diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index e4c05dcc2..03d929b4d 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -151,6 +151,13 @@ def _is_pandas_df(data): return False return isinstance(data, pd.DataFrame) +def _is_modin_df(data): + try: + import modin.pandas as pd + except ImportError: + return False + return isinstance(data, pd.DataFrame) + _pandas_dtype_mapper = { 'int8': 'int', @@ -208,8 +215,8 @@ def _transform_pandas_df(data, feature_names=None, feature_types=None, 'DataFrame for {meta} cannot have multiple columns'.format( meta=meta)) - dtype = meta_type if meta_type else 'float' - data = data.values.astype(dtype) + dtype = meta_type if meta_type else np.float32 + data = np.ascontiguousarray(data.values, dtype=dtype) return data, feature_names, feature_types @@ -228,6 +235,13 @@ def _is_pandas_series(data): return False return isinstance(data, pd.Series) +def _is_modin_series(data): + try: + import modin.pandas as pd + except ImportError: + return False + return isinstance(data, pd.Series) + def _from_pandas_series(data, missing, nthread, feature_types, feature_names): return _from_numpy_array(data.values.astype('float'), missing, nthread, @@ -525,6 +539,12 @@ def dispatch_data_backend(data, missing, threads, _warn_unused_missing(data, missing) return _from_dt_df(data, missing, threads, feature_names, feature_types) + if _is_modin_df(data): + return _from_pandas_df(data, missing, threads, + feature_names, feature_types) + if _is_modin_series(data): + return _from_pandas_series(data, missing, threads, feature_names, + feature_types) if _has_array_protocol(data): pass raise TypeError('Not supported type for data.' + str(type(data))) @@ -648,6 +668,15 @@ def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None): if _is_dt_df(data): _meta_from_dt(data, name, dtype, handle) return + if _is_modin_df(data): + data, _, _ = _transform_pandas_df(data, meta=name, meta_type=dtype) + _meta_from_numpy(data, name, dtype, handle) + return + if _is_modin_series(data): + data = data.values.astype('float') + assert len(data.shape) == 1 or data.shape[1] == 0 or data.shape[1] == 1 + _meta_from_numpy(data, name, dtype, handle) + return if _has_array_protocol(data): pass raise TypeError('Unsupported type for ' + name, str(type(data))) diff --git a/tests/ci_build/conda_env/cpu_test.yml b/tests/ci_build/conda_env/cpu_test.yml index 027a85df1..3312121e0 100644 --- a/tests/ci_build/conda_env/cpu_test.yml +++ b/tests/ci_build/conda_env/cpu_test.yml @@ -31,3 +31,4 @@ dependencies: - pip: - guzzle_sphinx_theme - datatable + - modin[all] diff --git a/tests/ci_build/conda_env/win64_test.yml b/tests/ci_build/conda_env/win64_test.yml index ad84aebc3..df06ebff2 100644 --- a/tests/ci_build/conda_env/win64_test.yml +++ b/tests/ci_build/conda_env/win64_test.yml @@ -16,3 +16,4 @@ dependencies: - pip - pip: - cupy-cuda101 + - modin[all] diff --git a/tests/python/test_with_modin.py b/tests/python/test_with_modin.py new file mode 100644 index 000000000..d79672631 --- /dev/null +++ b/tests/python/test_with_modin.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +import numpy as np +import xgboost as xgb +import testing as tm +import unittest +import pytest + +try: + import modin.pandas as md +except ImportError: + pass + + +pytestmark = pytest.mark.skipif(**tm.no_modin()) + + +dpath = 'demo/data/' +rng = np.random.RandomState(1994) + + +class TestModin(unittest.TestCase): + + def test_modin(self): + + df = md.DataFrame([[1, 2., True], [2, 3., False]], + columns=['a', 'b', 'c']) + dm = xgb.DMatrix(df, label=md.Series([1, 2])) + assert dm.feature_names == ['a', 'b', 'c'] + assert dm.feature_types == ['int', 'float', 'i'] + assert dm.num_row() == 2 + assert dm.num_col() == 3 + np.testing.assert_array_equal(dm.get_label(), np.array([1, 2])) + + # overwrite feature_names and feature_types + dm = xgb.DMatrix(df, label=md.Series([1, 2]), + feature_names=['x', 'y', 'z'], + feature_types=['q', 'q', 'q']) + assert dm.feature_names == ['x', 'y', 'z'] + assert dm.feature_types == ['q', 'q', 'q'] + assert dm.num_row() == 2 + assert dm.num_col() == 3 + + # incorrect dtypes + df = md.DataFrame([[1, 2., 'x'], [2, 3., 'y']], + columns=['a', 'b', 'c']) + self.assertRaises(ValueError, xgb.DMatrix, df) + + # numeric columns + df = md.DataFrame([[1, 2., True], [2, 3., False]]) + dm = xgb.DMatrix(df, label=md.Series([1, 2])) + assert dm.feature_names == ['0', '1', '2'] + assert dm.feature_types == ['int', 'float', 'i'] + assert dm.num_row() == 2 + assert dm.num_col() == 3 + np.testing.assert_array_equal(dm.get_label(), np.array([1, 2])) + + df = md.DataFrame([[1, 2., 1], [2, 3., 1]], columns=[4, 5, 6]) + dm = xgb.DMatrix(df, label=md.Series([1, 2])) + assert dm.feature_names == ['4', '5', '6'] + assert dm.feature_types == ['int', 'float', 'int'] + assert dm.num_row() == 2 + assert dm.num_col() == 3 + + df = md.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]}) + dummies = md.get_dummies(df) + # B A_X A_Y A_Z + # 0 1 1 0 0 + # 1 2 0 1 0 + # 2 3 0 0 1 + result, _, _ = xgb.data._transform_pandas_df(dummies) + exp = np.array([[1., 1., 0., 0.], + [2., 0., 1., 0.], + [3., 0., 0., 1.]]) + np.testing.assert_array_equal(result, exp) + dm = xgb.DMatrix(dummies) + assert dm.feature_names == ['B', 'A_X', 'A_Y', 'A_Z'] + assert dm.feature_types == ['int', 'int', 'int', 'int'] + assert dm.num_row() == 3 + assert dm.num_col() == 4 + + df = md.DataFrame({'A=1': [1, 2, 3], 'A=2': [4, 5, 6]}) + dm = xgb.DMatrix(df) + assert dm.feature_names == ['A=1', 'A=2'] + assert dm.feature_types == ['int', 'int'] + assert dm.num_row() == 3 + assert dm.num_col() == 2 + + df_int = md.DataFrame([[1, 1.1], [2, 2.2]], columns=[9, 10]) + dm_int = xgb.DMatrix(df_int) + df_range = md.DataFrame([[1, 1.1], [2, 2.2]], columns=range(9, 11, 1)) + dm_range = xgb.DMatrix(df_range) + assert dm_int.feature_names == ['9', '10'] # assert not "9 " + assert dm_int.feature_names == dm_range.feature_names + + # test MultiIndex as columns + df = md.DataFrame( + [ + (1, 2, 3, 4, 5, 6), + (6, 5, 4, 3, 2, 1) + ], + columns=md.MultiIndex.from_tuples(( + ('a', 1), ('a', 2), ('a', 3), + ('b', 1), ('b', 2), ('b', 3), + )) + ) + dm = xgb.DMatrix(df) + assert dm.feature_names == ['a 1', 'a 2', 'a 3', 'b 1', 'b 2', 'b 3'] + assert dm.feature_types == ['int', 'int', 'int', 'int', 'int', 'int'] + assert dm.num_row() == 2 + assert dm.num_col() == 6 + + def test_modin_label(self): + # label must be a single column + df = md.DataFrame({'A': ['X', 'Y', 'Z'], 'B': [1, 2, 3]}) + self.assertRaises(ValueError, xgb.data._transform_pandas_df, df, + None, None, 'label', 'float') + + # label must be supported dtype + df = md.DataFrame({'A': np.array(['a', 'b', 'c'], dtype=object)}) + self.assertRaises(ValueError, xgb.data._transform_pandas_df, df, + None, None, 'label', 'float') + + df = md.DataFrame({'A': np.array([1, 2, 3], dtype=int)}) + result, _, _ = xgb.data._transform_pandas_df(df, None, None, + 'label', 'float') + np.testing.assert_array_equal(result, np.array([[1.], [2.], [3.]], + dtype=float)) + dm = xgb.DMatrix(np.random.randn(3, 2), label=df) + assert dm.num_row() == 3 + assert dm.num_col() == 2 + + def test_modin_weight(self): + kRows = 32 + kCols = 8 + + X = np.random.randn(kRows, kCols) + y = np.random.randn(kRows) + w = np.random.randn(kRows).astype(np.float32) + w_pd = md.DataFrame(w) + data = xgb.DMatrix(X, y, w_pd) + + assert data.num_row() == kRows + assert data.num_col() == kCols + + np.testing.assert_array_equal(data.get_weight(), w) diff --git a/tests/python/testing.py b/tests/python/testing.py index 0c462518f..f6a05a5d7 100644 --- a/tests/python/testing.py +++ b/tests/python/testing.py @@ -37,6 +37,15 @@ def no_pandas(): 'reason': 'Pandas is not installed.'} +def no_modin(): + reason = 'Modin is not installed.' + try: + import modin.pandas as _ # noqa + return {'condition': False, 'reason': reason} + except ImportError: + return {'condition': True, 'reason': reason} + + def no_dt(): import importlib.util spec = importlib.util.find_spec('datatable')