Fix DMatrix feature names/types IO. (#6507)
* Fix feature names/types IO Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
import collections
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from collections.abc import Mapping
|
||||
from typing import List, Optional, Any, Union
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from typing import Dict, Union, List
|
||||
import ctypes
|
||||
@@ -508,8 +509,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
self.set_info(label=label, weight=weight, base_margin=base_margin)
|
||||
|
||||
self.feature_names = feature_names
|
||||
self.feature_types = feature_types
|
||||
if feature_names is not None:
|
||||
self.feature_names = feature_names
|
||||
if feature_types is not None:
|
||||
self.feature_types = feature_types
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "handle") and self.handle:
|
||||
@@ -784,7 +787,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
return res
|
||||
|
||||
@property
|
||||
def feature_names(self):
|
||||
def feature_names(self) -> List[str]:
|
||||
"""Get feature names (column labels).
|
||||
|
||||
Returns
|
||||
@@ -793,18 +796,21 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
"""
|
||||
length = c_bst_ulong()
|
||||
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
||||
_check_call(_LIB.XGDMatrixGetStrFeatureInfo(self.handle,
|
||||
c_str('feature_name'),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr)))
|
||||
_check_call(
|
||||
_LIB.XGDMatrixGetStrFeatureInfo(
|
||||
self.handle,
|
||||
c_str("feature_name"),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr),
|
||||
)
|
||||
)
|
||||
feature_names = from_cstr_to_pystr(sarr, length)
|
||||
if not feature_names:
|
||||
feature_names = ['f{0}'.format(i)
|
||||
for i in range(self.num_col())]
|
||||
feature_names = ["f{0}".format(i) for i in range(self.num_col())]
|
||||
return feature_names
|
||||
|
||||
@feature_names.setter
|
||||
def feature_names(self, feature_names):
|
||||
def feature_names(self, feature_names: Optional[Union[List[str], str]]) -> None:
|
||||
"""Set feature names (column labels).
|
||||
|
||||
Parameters
|
||||
@@ -828,12 +834,11 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
msg = 'feature_names must have the same length as data'
|
||||
raise ValueError(msg)
|
||||
# prohibit to use symbols may affect to parse. e.g. []<
|
||||
if not all(isinstance(f, STRING_TYPES) and
|
||||
if not all(isinstance(f, str) and
|
||||
not any(x in f for x in set(('[', ']', '<')))
|
||||
for f in feature_names):
|
||||
raise ValueError('feature_names must be string, and may not contain [, ] or <')
|
||||
c_feature_names = [bytes(f, encoding='utf-8')
|
||||
for f in feature_names]
|
||||
c_feature_names = [bytes(f, encoding='utf-8') for f in feature_names]
|
||||
c_feature_names = (ctypes.c_char_p *
|
||||
len(c_feature_names))(*c_feature_names)
|
||||
_check_call(_LIB.XGDMatrixSetStrFeatureInfo(
|
||||
@@ -850,7 +855,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
self.feature_types = None
|
||||
|
||||
@property
|
||||
def feature_types(self):
|
||||
def feature_types(self) -> Optional[List[str]]:
|
||||
"""Get feature types (column types).
|
||||
|
||||
Returns
|
||||
@@ -869,7 +874,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
return res
|
||||
|
||||
@feature_types.setter
|
||||
def feature_types(self, feature_types):
|
||||
def feature_types(self, feature_types: Optional[Union[List[Any], Any]]) -> None:
|
||||
"""Set feature types (column types).
|
||||
|
||||
This is for displaying the results and unrelated
|
||||
@@ -884,7 +889,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
if not isinstance(feature_types, (list, str)):
|
||||
raise TypeError(
|
||||
'feature_types must be string or list of strings')
|
||||
if isinstance(feature_types, STRING_TYPES):
|
||||
if isinstance(feature_types, str):
|
||||
# single string will be applied to all columns
|
||||
feature_types = [feature_types] * self.num_col()
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user