Cleaned up some code

This commit is contained in:
Johan Manders 2015-11-04 18:05:47 +01:00
parent b0f38e9352
commit 5f0f8749d9

View File

@ -148,20 +148,19 @@ def _maybe_from_pandas(data, label, feature_names, feature_types):
if not isinstance(data, pd.DataFrame): if not isinstance(data, pd.DataFrame):
return data, label, feature_names, feature_types return data, label, feature_names, feature_types
mapper = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
'float16': 'float', 'float32': 'float', 'float64': 'float',
'bool': 'i'}
data_dtypes = data.dtypes data_dtypes = data.dtypes
if not all(dtype.name in ('int8', 'int16', 'int32', 'int64', if not all(dtype.name in (mapper.keys()) for dtype in data_dtypes):
'uint8', 'uint16', 'uint32', 'uint64',
'float16', 'float32', 'float64',
'bool') for dtype in data_dtypes):
raise ValueError('DataFrame.dtypes for data must be int, float or bool') raise ValueError('DataFrame.dtypes for data must be int, float or bool')
if label is not None: if label is not None:
if isinstance(label, pd.DataFrame): if isinstance(label, pd.DataFrame):
label_dtypes = label.dtypes label_dtypes = label.dtypes
if not all(dtype.name in ('int8', 'int16', 'int32', 'int64', if not all(dtype.name in (mapper.keys()) for dtype in label_dtypes):
'uint8', 'uint16', 'uint32', 'uint64',
'float16', 'float32', 'float64',
'bool') for dtype in label_dtypes):
raise ValueError('DataFrame.dtypes for label must be int, float or bool') raise ValueError('DataFrame.dtypes for label must be int, float or bool')
else: else:
label = label.values.astype('float') label = label.values.astype('float')
@ -170,10 +169,6 @@ def _maybe_from_pandas(data, label, feature_names, feature_types):
feature_names = data.columns.format() feature_names = data.columns.format()
if feature_types is None: if feature_types is None:
mapper = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
'float16': 'float', 'float32': 'float', 'float64': 'float',
'bool': 'i'}
feature_types = [mapper[dtype.name] for dtype in data_dtypes] feature_types = [mapper[dtype.name] for dtype in data_dtypes]
data = data.values.astype('float') data = data.values.astype('float')