From 5f0f8749d90f585ccf0deb61a7ff8ec28cefa7af Mon Sep 17 00:00:00 2001 From: Johan Manders Date: Wed, 4 Nov 2015 18:05:47 +0100 Subject: [PATCH] Cleaned up some code --- python-package/xgboost/core.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 93a73152c..a91019a8c 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -148,20 +148,19 @@ def _maybe_from_pandas(data, label, feature_names, feature_types): if not isinstance(data, pd.DataFrame): return data, label, feature_names, feature_types + mapper = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int', + 'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int', + 'float16': 'float', 'float32': 'float', 'float64': 'float', + 'bool': 'i'} + data_dtypes = data.dtypes - if not all(dtype.name in ('int8', 'int16', 'int32', 'int64', - 'uint8', 'uint16', 'uint32', 'uint64', - 'float16', 'float32', 'float64', - 'bool') for dtype in data_dtypes): + if not all(dtype.name in (mapper.keys()) for dtype in data_dtypes): raise ValueError('DataFrame.dtypes for data must be int, float or bool') if label is not None: if isinstance(label, pd.DataFrame): label_dtypes = label.dtypes - if not all(dtype.name in ('int8', 'int16', 'int32', 'int64', - 'uint8', 'uint16', 'uint32', 'uint64', - 'float16', 'float32', 'float64', - 'bool') for dtype in label_dtypes): + if not all(dtype.name in (mapper.keys()) for dtype in label_dtypes): raise ValueError('DataFrame.dtypes for label must be int, float or bool') else: label = label.values.astype('float') @@ -170,10 +169,6 @@ def _maybe_from_pandas(data, label, feature_names, feature_types): feature_names = data.columns.format() if feature_types is None: - mapper = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int', - 'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int', - 'float16': 'float', 'float32': 'float', 'float64': 'float', - 'bool': 'i'} feature_types = [mapper[dtype.name] for dtype in data_dtypes] data = data.values.astype('float')