Fix DMatrix feature names/types IO. (#6507)

* Fix feature names/types IO

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan 2020-12-16 14:24:27 +08:00 committed by GitHub
parent 886486a519
commit ef4a0e0aac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 16 deletions

View File

@ -5,6 +5,7 @@
import collections
# pylint: disable=no-name-in-module,import-error
from collections.abc import Mapping
from typing import List, Optional, Any, Union
# pylint: enable=no-name-in-module,import-error
from typing import Dict, Union, List
import ctypes
@ -508,8 +509,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
self.set_info(label=label, weight=weight, base_margin=base_margin)
self.feature_names = feature_names
self.feature_types = feature_types
if feature_names is not None:
self.feature_names = feature_names
if feature_types is not None:
self.feature_types = feature_types
def __del__(self):
if hasattr(self, "handle") and self.handle:
@ -784,7 +787,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
return res
@property
def feature_names(self):
def feature_names(self) -> List[str]:
"""Get feature names (column labels).
Returns
@ -793,18 +796,21 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
"""
length = c_bst_ulong()
sarr = ctypes.POINTER(ctypes.c_char_p)()
_check_call(_LIB.XGDMatrixGetStrFeatureInfo(self.handle,
c_str('feature_name'),
ctypes.byref(length),
ctypes.byref(sarr)))
_check_call(
_LIB.XGDMatrixGetStrFeatureInfo(
self.handle,
c_str("feature_name"),
ctypes.byref(length),
ctypes.byref(sarr),
)
)
feature_names = from_cstr_to_pystr(sarr, length)
if not feature_names:
feature_names = ['f{0}'.format(i)
for i in range(self.num_col())]
feature_names = ["f{0}".format(i) for i in range(self.num_col())]
return feature_names
@feature_names.setter
def feature_names(self, feature_names):
def feature_names(self, feature_names: Optional[Union[List[str], str]]) -> None:
"""Set feature names (column labels).
Parameters
@ -828,12 +834,11 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
msg = 'feature_names must have the same length as data'
raise ValueError(msg)
# Prohibit symbols that may interfere with parsing, e.g. [, ] and <
if not all(isinstance(f, STRING_TYPES) and
if not all(isinstance(f, str) and
not any(x in f for x in set(('[', ']', '<')))
for f in feature_names):
raise ValueError('feature_names must be string, and may not contain [, ] or <')
c_feature_names = [bytes(f, encoding='utf-8')
for f in feature_names]
c_feature_names = [bytes(f, encoding='utf-8') for f in feature_names]
c_feature_names = (ctypes.c_char_p *
len(c_feature_names))(*c_feature_names)
_check_call(_LIB.XGDMatrixSetStrFeatureInfo(
@ -850,7 +855,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
self.feature_types = None
@property
def feature_types(self):
def feature_types(self) -> Optional[List[str]]:
"""Get feature types (column types).
Returns
@ -869,7 +874,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
return res
@feature_types.setter
def feature_types(self, feature_types):
def feature_types(self, feature_types: Optional[Union[List[Any], Any]]) -> None:
"""Set feature types (column types).
This is for displaying the results and unrelated
@ -884,7 +889,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
if not isinstance(feature_types, (list, str)):
raise TypeError(
'feature_types must be string or list of strings')
if isinstance(feature_types, STRING_TYPES):
if isinstance(feature_types, str):
# single string will be applied to all columns
feature_types = [feature_types] * self.num_col()
try:

View File

@ -1,10 +1,14 @@
# -*- coding: utf-8 -*-
import os
import tempfile
import numpy as np
import xgboost as xgb
import scipy.sparse
import pytest
from scipy.sparse import rand, csr_matrix
import testing as tm
rng = np.random.RandomState(1)
dpath = 'demo/data/'
@ -207,6 +211,23 @@ class TestDMatrix:
with pytest.raises(ValueError):
bst.predict(dm)
@pytest.mark.skipif(**tm.no_pandas())
def test_save_binary(self):
    """Check that feature names/types survive a binary save/load round trip.

    A DMatrix is built from two pandas columns, written with
    ``save_binary`` and loaded back from the resulting file; the reloaded
    matrix must report the same feature names and feature types.
    """
    import pandas as pd

    frame = pd.DataFrame({"a": [0, 1], "b": [2, 3], "c": [4, 5]})
    predictors = frame.loc[:, ["a", "b"]]
    target = frame["c"]

    original = xgb.DMatrix(predictors, target)
    # Column labels of the DataFrame are expected to become feature names.
    assert original.feature_names == ['a', 'b']

    with tempfile.TemporaryDirectory() as tmpdir:
        binary_path = os.path.join(tmpdir, 'm.dmatrix')
        original.save_binary(binary_path)

        # Reload while the temporary directory still exists.
        restored = xgb.DMatrix(binary_path)
        assert original.feature_names == restored.feature_names
        assert original.feature_types == restored.feature_types
def test_get_info(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtrain.get_float_info('label')