From 3b62e75f2ef39f6794f5f63b4fa43dbc30e64c53 Mon Sep 17 00:00:00 2001 From: wenduowang Date: Mon, 30 Jul 2018 08:36:34 -0600 Subject: [PATCH] Fix bug of using list(x) function when x is string (#3432) * Fix bug of using list(x) function when x is string list('abcdcba') = ['a', 'b', 'c', 'd', 'c', 'b', 'a'] * Allow feature_names/feature_types to be of any type If feature_names/feature_types is iterable, e.g. tuple, list, then convert the value to list, except for string; otherwise construct a list with a single value * Delete excess whitespace * Fix whitespace to pass lint --- python-package/xgboost/core.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 8e282cae9..b54718f43 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -766,8 +766,14 @@ class DMatrix(object): """ if feature_names is not None: # validate feature name - if not isinstance(feature_names, list): - feature_names = list(feature_names) + try: + if not isinstance(feature_names, str): + feature_names = [n for n in iter(feature_names)] + else: + feature_names = [feature_names] + except TypeError: + feature_names = [feature_names] + if len(feature_names) != len(set(feature_names)): raise ValueError('feature_names must be unique') if len(feature_names) != self.num_col(): @@ -796,7 +802,6 @@ class DMatrix(object): Labels for features. None will reset existing feature names """ if feature_types is not None: - if self._feature_names is None: msg = 'Unable to set feature types before setting names' raise ValueError(msg) @@ -805,8 +810,14 @@ class DMatrix(object): # single string will be applied to all columns feature_types = [feature_types] * self.num_col() - if not isinstance(feature_types, list): - feature_types = list(feature_types) + try: + if not isinstance(feature_types, str): + feature_types = [n for n in iter(feature_types)] + else: + feature_types = [feature_types] + except TypeError: + feature_types = [feature_types] + if len(feature_types) != self.num_col(): msg = 'feature_types must have the same length as data' raise ValueError(msg)