Cleanup pandas support

This commit is contained in:
sinhrks
2015-11-13 05:54:41 +09:00
parent 4fb6153eed
commit 25c4fbd0cb
5 changed files with 131 additions and 56 deletions

View File

@@ -0,0 +1,47 @@
# coding: utf-8
# pylint: disable=unused-import, invalid-name
"""For compatibility"""
from __future__ import absolute_import
import sys
# True when the running interpreter is Python 3.
PY3 = sys.version_info[0] == 3

# Tuple of types treated as "string" in isinstance() checks.
if not PY3:
    # pylint: disable=invalid-name
    STRING_TYPES = (basestring,)
else:
    # pylint: disable=invalid-name, redefined-builtin
    STRING_TYPES = (str,)
# pandas
try:
    from pandas import DataFrame
except ImportError:
    PANDAS_INSTALLED = False

    class DataFrame(object):
        """Stand-in for pandas.DataFrame used when pandas is not installed."""
else:
    PANDAS_INSTALLED = True
# sklearn
try:
    from sklearn.base import BaseEstimator
    from sklearn.base import RegressorMixin, ClassifierMixin
    from sklearn.preprocessing import LabelEncoder
    SKLEARN_INSTALLED = True

    XGBModelBase = BaseEstimator
    XGBRegressorBase = RegressorMixin
    XGBClassifierBase = ClassifierMixin
except ImportError:
    SKLEARN_INSTALLED = False

    # used for compatibility without sklearn
    XGBModelBase = object
    XGBClassifierBase = object
    XGBRegressorBase = object
    # Placeholder so that ``from .compat import LabelEncoder`` still succeeds
    # when scikit-learn is missing; without it the importing module would
    # itself fail with ImportError. Check SKLEARN_INSTALLED before using it.
    LabelEncoder = None

View File

@@ -4,7 +4,6 @@
from __future__ import absolute_import
import os
import sys
import ctypes
import collections
@@ -13,20 +12,12 @@ import scipy.sparse
from .libpath import find_lib_path
from .compat import STRING_TYPES, PY3, DataFrame
class XGBoostError(Exception):
    """Error thrown by xgboost trainer."""
PY3 = (sys.version_info[0] == 3)
if PY3:
# pylint: disable=invalid-name, redefined-builtin
STRING_TYPES = str,
else:
# pylint: disable=invalid-name
STRING_TYPES = basestring,
def from_pystr_to_cstr(data):
"""Convert a list of Python str to C pointer
@@ -138,42 +129,49 @@ def c_array(ctype, values):
return (ctype * len(values))(*values)
def _maybe_from_pandas(data, label, feature_names, feature_types):
""" Extract internal data from pd.DataFrame """
try:
import pandas as pd
except ImportError:
return data, label, feature_names, feature_types
if not isinstance(data, pd.DataFrame):
return data, label, feature_names, feature_types
# Maps the name of each accepted pandas column dtype to a feature-type string;
# a dtype name missing from this table is rejected with ValueError.
PANDAS_DTYPE_MAPPER = {
    # signed integers
    'int8': 'int',
    'int16': 'int',
    'int32': 'int',
    'int64': 'int',
    # unsigned integers
    'uint8': 'int',
    'uint16': 'int',
    'uint32': 'int',
    'uint64': 'int',
    # floating point
    'float16': 'float',
    'float32': 'float',
    'float64': 'float',
    # boolean
    'bool': 'i',
}
mapper = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
'float16': 'float', 'float32': 'float', 'float64': 'float',
'bool': 'i'}
def _maybe_pandas_data(data, feature_names, feature_types):
""" Extract internal data from pd.DataFrame for DMatrix data """
if not isinstance(data, DataFrame):
return data, feature_names, feature_types
data_dtypes = data.dtypes
if not all(dtype.name in (mapper.keys()) for dtype in data_dtypes):
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in data_dtypes):
raise ValueError('DataFrame.dtypes for data must be int, float or bool')
if label is not None:
if isinstance(label, pd.DataFrame):
label_dtypes = label.dtypes
if not all(dtype.name in (mapper.keys()) for dtype in label_dtypes):
raise ValueError('DataFrame.dtypes for label must be int, float or bool')
else:
label = label.values.astype('float')
if feature_names is None:
feature_names = data.columns.format()
if feature_types is None:
feature_types = [mapper[dtype.name] for dtype in data_dtypes]
feature_types = [PANDAS_DTYPE_MAPPER[dtype.name] for dtype in data_dtypes]
data = data.values.astype('float')
return data, label, feature_names, feature_types
return data, feature_names, feature_types
def _maybe_pandas_label(label):
""" Extract internal data from pd.DataFrame for DMatrix label """
if isinstance(label, DataFrame):
if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns')
label_dtypes = label.dtypes
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in label_dtypes):
raise ValueError('DataFrame.dtypes for label must be int, float or bool')
else:
label = label.values.astype('float')
# pd.Series can be passed to xgb as it is
return label
class DMatrix(object):
"""Data Matrix used in XGBoost.
@@ -216,13 +214,10 @@ class DMatrix(object):
self.handle = None
return
klass = getattr(getattr(data, '__class__', None), '__name__', None)
if klass == 'DataFrame':
# once check class name to avoid unnecessary pandas import
data, label, feature_names, feature_types = _maybe_from_pandas(data,
label,
feature_names,
feature_types)
data, feature_names, feature_types = _maybe_pandas_data(data,
feature_names,
feature_types)
label = _maybe_pandas_label(label)
if isinstance(data, STRING_TYPES):
self.handle = ctypes.c_void_p()

View File

@@ -7,23 +7,9 @@ import numpy as np
from .core import Booster, DMatrix, XGBoostError
from .training import train
try:
from sklearn.base import BaseEstimator
from sklearn.base import RegressorMixin, ClassifierMixin
from sklearn.preprocessing import LabelEncoder
SKLEARN_INSTALLED = True
except ImportError:
SKLEARN_INSTALLED = False
from .compat import (SKLEARN_INSTALLED, XGBModelBase,
XGBClassifierBase, XGBRegressorBase, LabelEncoder)
# used for compatibility without sklearn
XGBModelBase = object
XGBClassifierBase = object
XGBRegressorBase = object
if SKLEARN_INSTALLED:
XGBModelBase = BaseEstimator
XGBRegressorBase = RegressorMixin
XGBClassifierBase = ClassifierMixin
class XGBModel(XGBModelBase):
# pylint: disable=too-many-arguments, too-many-instance-attributes, invalid-name