Fix DMatrix feature names/types IO. (#6507)
* Fix feature names/types IO

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
886486a519
commit
ef4a0e0aac
@@ -5,6 +5,7 @@
|
|||||||
import collections
|
import collections
|
||||||
# pylint: disable=no-name-in-module,import-error
|
# pylint: disable=no-name-in-module,import-error
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
from typing import List, Optional, Any, Union
|
||||||
# pylint: enable=no-name-in-module,import-error
|
# pylint: enable=no-name-in-module,import-error
|
||||||
from typing import Dict, Union, List
|
from typing import Dict, Union, List
|
||||||
import ctypes
|
import ctypes
|
||||||
@@ -508,8 +509,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
|
|
||||||
self.set_info(label=label, weight=weight, base_margin=base_margin)
|
self.set_info(label=label, weight=weight, base_margin=base_margin)
|
||||||
|
|
||||||
self.feature_names = feature_names
|
if feature_names is not None:
|
||||||
self.feature_types = feature_types
|
self.feature_names = feature_names
|
||||||
|
if feature_types is not None:
|
||||||
|
self.feature_types = feature_types
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if hasattr(self, "handle") and self.handle:
|
if hasattr(self, "handle") and self.handle:
|
||||||
@@ -784,7 +787,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def feature_names(self):
|
def feature_names(self) -> List[str]:
|
||||||
"""Get feature names (column labels).
|
"""Get feature names (column labels).
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
@@ -793,18 +796,21 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
"""
|
"""
|
||||||
length = c_bst_ulong()
|
length = c_bst_ulong()
|
||||||
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
||||||
_check_call(_LIB.XGDMatrixGetStrFeatureInfo(self.handle,
|
_check_call(
|
||||||
c_str('feature_name'),
|
_LIB.XGDMatrixGetStrFeatureInfo(
|
||||||
ctypes.byref(length),
|
self.handle,
|
||||||
ctypes.byref(sarr)))
|
c_str("feature_name"),
|
||||||
|
ctypes.byref(length),
|
||||||
|
ctypes.byref(sarr),
|
||||||
|
)
|
||||||
|
)
|
||||||
feature_names = from_cstr_to_pystr(sarr, length)
|
feature_names = from_cstr_to_pystr(sarr, length)
|
||||||
if not feature_names:
|
if not feature_names:
|
||||||
feature_names = ['f{0}'.format(i)
|
feature_names = ["f{0}".format(i) for i in range(self.num_col())]
|
||||||
for i in range(self.num_col())]
|
|
||||||
return feature_names
|
return feature_names
|
||||||
|
|
||||||
@feature_names.setter
|
@feature_names.setter
|
||||||
def feature_names(self, feature_names):
|
def feature_names(self, feature_names: Optional[Union[List[str], str]]) -> None:
|
||||||
"""Set feature names (column labels).
|
"""Set feature names (column labels).
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -828,12 +834,11 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
msg = 'feature_names must have the same length as data'
|
msg = 'feature_names must have the same length as data'
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
# Prohibit symbols that may interfere with parsing, e.g. [, ] and <
|
# Prohibit symbols that may interfere with parsing, e.g. [, ] and <
|
||||||
if not all(isinstance(f, STRING_TYPES) and
|
if not all(isinstance(f, str) and
|
||||||
not any(x in f for x in set(('[', ']', '<')))
|
not any(x in f for x in set(('[', ']', '<')))
|
||||||
for f in feature_names):
|
for f in feature_names):
|
||||||
raise ValueError('feature_names must be string, and may not contain [, ] or <')
|
raise ValueError('feature_names must be string, and may not contain [, ] or <')
|
||||||
c_feature_names = [bytes(f, encoding='utf-8')
|
c_feature_names = [bytes(f, encoding='utf-8') for f in feature_names]
|
||||||
for f in feature_names]
|
|
||||||
c_feature_names = (ctypes.c_char_p *
|
c_feature_names = (ctypes.c_char_p *
|
||||||
len(c_feature_names))(*c_feature_names)
|
len(c_feature_names))(*c_feature_names)
|
||||||
_check_call(_LIB.XGDMatrixSetStrFeatureInfo(
|
_check_call(_LIB.XGDMatrixSetStrFeatureInfo(
|
||||||
@@ -850,7 +855,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
self.feature_types = None
|
self.feature_types = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def feature_types(self):
|
def feature_types(self) -> Optional[List[str]]:
|
||||||
"""Get feature types (column types).
|
"""Get feature types (column types).
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
@@ -869,7 +874,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
@feature_types.setter
|
@feature_types.setter
|
||||||
def feature_types(self, feature_types):
|
def feature_types(self, feature_types: Optional[Union[List[Any], Any]]) -> None:
|
||||||
"""Set feature types (column types).
|
"""Set feature types (column types).
|
||||||
|
|
||||||
This is for displaying the results and unrelated
|
This is for displaying the results and unrelated
|
||||||
@@ -884,7 +889,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
if not isinstance(feature_types, (list, str)):
|
if not isinstance(feature_types, (list, str)):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
'feature_types must be string or list of strings')
|
'feature_types must be string or list of strings')
|
||||||
if isinstance(feature_types, STRING_TYPES):
|
if isinstance(feature_types, str):
|
||||||
# single string will be applied to all columns
|
# single string will be applied to all columns
|
||||||
feature_types = [feature_types] * self.num_col()
|
feature_types = [feature_types] * self.num_col()
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,10 +1,14 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import xgboost as xgb
|
import xgboost as xgb
|
||||||
import scipy.sparse
|
import scipy.sparse
|
||||||
import pytest
|
import pytest
|
||||||
from scipy.sparse import rand, csr_matrix
|
from scipy.sparse import rand, csr_matrix
|
||||||
|
|
||||||
|
import testing as tm
|
||||||
|
|
||||||
rng = np.random.RandomState(1)
|
rng = np.random.RandomState(1)
|
||||||
|
|
||||||
dpath = 'demo/data/'
|
dpath = 'demo/data/'
|
||||||
@@ -207,6 +211,23 @@ class TestDMatrix:
|
|||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
bst.predict(dm)
|
bst.predict(dm)
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_pandas())
|
||||||
|
def test_save_binary(self):
|
||||||
|
import pandas as pd
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
path = os.path.join(tmpdir, 'm.dmatrix')
|
||||||
|
data = pd.DataFrame({
|
||||||
|
"a": [0, 1],
|
||||||
|
"b": [2, 3],
|
||||||
|
"c": [4, 5]
|
||||||
|
})
|
||||||
|
m0 = xgb.DMatrix(data.loc[:, ["a", "b"]], data["c"])
|
||||||
|
assert m0.feature_names == ['a', 'b']
|
||||||
|
m0.save_binary(path)
|
||||||
|
m1 = xgb.DMatrix(path)
|
||||||
|
assert m0.feature_names == m1.feature_names
|
||||||
|
assert m0.feature_types == m1.feature_types
|
||||||
|
|
||||||
def test_get_info(self):
|
def test_get_info(self):
|
||||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||||
dtrain.get_float_info('label')
|
dtrain.get_float_info('label')
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user