Fix DMatrix feature names/types IO. (#6507)
* Fix feature names/types IO Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
886486a519
commit
ef4a0e0aac
@ -5,6 +5,7 @@
|
||||
import collections
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from collections.abc import Mapping
|
||||
from typing import List, Optional, Any, Union
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from typing import Dict, Union, List
|
||||
import ctypes
|
||||
@ -508,8 +509,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
|
||||
self.set_info(label=label, weight=weight, base_margin=base_margin)
|
||||
|
||||
self.feature_names = feature_names
|
||||
self.feature_types = feature_types
|
||||
if feature_names is not None:
|
||||
self.feature_names = feature_names
|
||||
if feature_types is not None:
|
||||
self.feature_types = feature_types
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "handle") and self.handle:
|
||||
@ -784,7 +787,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
return res
|
||||
|
||||
@property
|
||||
def feature_names(self):
|
||||
def feature_names(self) -> List[str]:
|
||||
"""Get feature names (column labels).
|
||||
|
||||
Returns
|
||||
@ -793,18 +796,21 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
"""
|
||||
length = c_bst_ulong()
|
||||
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
||||
_check_call(_LIB.XGDMatrixGetStrFeatureInfo(self.handle,
|
||||
c_str('feature_name'),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr)))
|
||||
_check_call(
|
||||
_LIB.XGDMatrixGetStrFeatureInfo(
|
||||
self.handle,
|
||||
c_str("feature_name"),
|
||||
ctypes.byref(length),
|
||||
ctypes.byref(sarr),
|
||||
)
|
||||
)
|
||||
feature_names = from_cstr_to_pystr(sarr, length)
|
||||
if not feature_names:
|
||||
feature_names = ['f{0}'.format(i)
|
||||
for i in range(self.num_col())]
|
||||
feature_names = ["f{0}".format(i) for i in range(self.num_col())]
|
||||
return feature_names
|
||||
|
||||
@feature_names.setter
|
||||
def feature_names(self, feature_names):
|
||||
def feature_names(self, feature_names: Optional[Union[List[str], str]]) -> None:
|
||||
"""Set feature names (column labels).
|
||||
|
||||
Parameters
|
||||
@ -828,12 +834,11 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
msg = 'feature_names must have the same length as data'
|
||||
raise ValueError(msg)
|
||||
# prohibit to use symbols may affect to parse. e.g. []<
|
||||
if not all(isinstance(f, STRING_TYPES) and
|
||||
if not all(isinstance(f, str) and
|
||||
not any(x in f for x in set(('[', ']', '<')))
|
||||
for f in feature_names):
|
||||
raise ValueError('feature_names must be string, and may not contain [, ] or <')
|
||||
c_feature_names = [bytes(f, encoding='utf-8')
|
||||
for f in feature_names]
|
||||
c_feature_names = [bytes(f, encoding='utf-8') for f in feature_names]
|
||||
c_feature_names = (ctypes.c_char_p *
|
||||
len(c_feature_names))(*c_feature_names)
|
||||
_check_call(_LIB.XGDMatrixSetStrFeatureInfo(
|
||||
@ -850,7 +855,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
self.feature_types = None
|
||||
|
||||
@property
|
||||
def feature_types(self):
|
||||
def feature_types(self) -> Optional[List[str]]:
|
||||
"""Get feature types (column types).
|
||||
|
||||
Returns
|
||||
@ -869,7 +874,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
return res
|
||||
|
||||
@feature_types.setter
|
||||
def feature_types(self, feature_types):
|
||||
def feature_types(self, feature_types: Optional[Union[List[Any], Any]]) -> None:
|
||||
"""Set feature types (column types).
|
||||
|
||||
This is for displaying the results and unrelated
|
||||
@ -884,7 +889,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
if not isinstance(feature_types, (list, str)):
|
||||
raise TypeError(
|
||||
'feature_types must be string or list of strings')
|
||||
if isinstance(feature_types, STRING_TYPES):
|
||||
if isinstance(feature_types, str):
|
||||
# single string will be applied to all columns
|
||||
feature_types = [feature_types] * self.num_col()
|
||||
try:
|
||||
|
||||
@ -1,10 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import tempfile
|
||||
import numpy as np
|
||||
import xgboost as xgb
|
||||
import scipy.sparse
|
||||
import pytest
|
||||
from scipy.sparse import rand, csr_matrix
|
||||
|
||||
import testing as tm
|
||||
|
||||
rng = np.random.RandomState(1)
|
||||
|
||||
dpath = 'demo/data/'
|
||||
@ -207,6 +211,23 @@ class TestDMatrix:
|
||||
with pytest.raises(ValueError):
|
||||
bst.predict(dm)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
def test_save_binary(self):
|
||||
import pandas as pd
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = os.path.join(tmpdir, 'm.dmatrix')
|
||||
data = pd.DataFrame({
|
||||
"a": [0, 1],
|
||||
"b": [2, 3],
|
||||
"c": [4, 5]
|
||||
})
|
||||
m0 = xgb.DMatrix(data.loc[:, ["a", "b"]], data["c"])
|
||||
assert m0.feature_names == ['a', 'b']
|
||||
m0.save_binary(path)
|
||||
m1 = xgb.DMatrix(path)
|
||||
assert m0.feature_names == m1.feature_names
|
||||
assert m0.feature_types == m1.feature_types
|
||||
|
||||
def test_get_info(self):
|
||||
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
|
||||
dtrain.get_float_info('label')
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user