Fix DMatrix feature names/types IO. (#6507)

* Fix feature names/types IO

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
Jiaming Yuan 2020-12-16 14:24:27 +08:00 committed by GitHub
parent 886486a519
commit ef4a0e0aac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 16 deletions

View File

@ -5,6 +5,7 @@
import collections import collections
# pylint: disable=no-name-in-module,import-error # pylint: disable=no-name-in-module,import-error
from collections.abc import Mapping from collections.abc import Mapping
from typing import List, Optional, Any, Union
# pylint: enable=no-name-in-module,import-error # pylint: enable=no-name-in-module,import-error
from typing import Dict, Union, List from typing import Dict, Union, List
import ctypes import ctypes
@ -508,8 +509,10 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
self.set_info(label=label, weight=weight, base_margin=base_margin) self.set_info(label=label, weight=weight, base_margin=base_margin)
self.feature_names = feature_names if feature_names is not None:
self.feature_types = feature_types self.feature_names = feature_names
if feature_types is not None:
self.feature_types = feature_types
def __del__(self): def __del__(self):
if hasattr(self, "handle") and self.handle: if hasattr(self, "handle") and self.handle:
@ -784,7 +787,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
return res return res
@property @property
def feature_names(self): def feature_names(self) -> List[str]:
"""Get feature names (column labels). """Get feature names (column labels).
Returns Returns
@ -793,18 +796,21 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
""" """
length = c_bst_ulong() length = c_bst_ulong()
sarr = ctypes.POINTER(ctypes.c_char_p)() sarr = ctypes.POINTER(ctypes.c_char_p)()
_check_call(_LIB.XGDMatrixGetStrFeatureInfo(self.handle, _check_call(
c_str('feature_name'), _LIB.XGDMatrixGetStrFeatureInfo(
ctypes.byref(length), self.handle,
ctypes.byref(sarr))) c_str("feature_name"),
ctypes.byref(length),
ctypes.byref(sarr),
)
)
feature_names = from_cstr_to_pystr(sarr, length) feature_names = from_cstr_to_pystr(sarr, length)
if not feature_names: if not feature_names:
feature_names = ['f{0}'.format(i) feature_names = ["f{0}".format(i) for i in range(self.num_col())]
for i in range(self.num_col())]
return feature_names return feature_names
@feature_names.setter @feature_names.setter
def feature_names(self, feature_names): def feature_names(self, feature_names: Optional[Union[List[str], str]]) -> None:
"""Set feature names (column labels). """Set feature names (column labels).
Parameters Parameters
@ -828,12 +834,11 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
msg = 'feature_names must have the same length as data' msg = 'feature_names must have the same length as data'
raise ValueError(msg) raise ValueError(msg)
# prohibit to use symbols may affect to parse. e.g. []< # prohibit to use symbols may affect to parse. e.g. []<
if not all(isinstance(f, STRING_TYPES) and if not all(isinstance(f, str) and
not any(x in f for x in set(('[', ']', '<'))) not any(x in f for x in set(('[', ']', '<')))
for f in feature_names): for f in feature_names):
raise ValueError('feature_names must be string, and may not contain [, ] or <') raise ValueError('feature_names must be string, and may not contain [, ] or <')
c_feature_names = [bytes(f, encoding='utf-8') c_feature_names = [bytes(f, encoding='utf-8') for f in feature_names]
for f in feature_names]
c_feature_names = (ctypes.c_char_p * c_feature_names = (ctypes.c_char_p *
len(c_feature_names))(*c_feature_names) len(c_feature_names))(*c_feature_names)
_check_call(_LIB.XGDMatrixSetStrFeatureInfo( _check_call(_LIB.XGDMatrixSetStrFeatureInfo(
@ -850,7 +855,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
self.feature_types = None self.feature_types = None
@property @property
def feature_types(self): def feature_types(self) -> Optional[List[str]]:
"""Get feature types (column types). """Get feature types (column types).
Returns Returns
@ -869,7 +874,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
return res return res
@feature_types.setter @feature_types.setter
def feature_types(self, feature_types): def feature_types(self, feature_types: Optional[Union[List[Any], Any]]) -> None:
"""Set feature types (column types). """Set feature types (column types).
This is for displaying the results and unrelated This is for displaying the results and unrelated
@ -884,7 +889,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
if not isinstance(feature_types, (list, str)): if not isinstance(feature_types, (list, str)):
raise TypeError( raise TypeError(
'feature_types must be string or list of strings') 'feature_types must be string or list of strings')
if isinstance(feature_types, STRING_TYPES): if isinstance(feature_types, str):
# single string will be applied to all columns # single string will be applied to all columns
feature_types = [feature_types] * self.num_col() feature_types = [feature_types] * self.num_col()
try: try:

View File

@ -1,10 +1,14 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import os
import tempfile
import numpy as np import numpy as np
import xgboost as xgb import xgboost as xgb
import scipy.sparse import scipy.sparse
import pytest import pytest
from scipy.sparse import rand, csr_matrix from scipy.sparse import rand, csr_matrix
import testing as tm
rng = np.random.RandomState(1) rng = np.random.RandomState(1)
dpath = 'demo/data/' dpath = 'demo/data/'
@ -207,6 +211,23 @@ class TestDMatrix:
with pytest.raises(ValueError): with pytest.raises(ValueError):
bst.predict(dm) bst.predict(dm)
@pytest.mark.skipif(**tm.no_pandas())
def test_save_binary(self):
    """Check feature names/types are preserved by a save_binary round trip.

    A DMatrix is built from two pandas columns, written to a temporary
    file with ``save_binary`` and loaded back from that path; the reloaded
    matrix must report the same feature names and feature types.
    """
    import pandas as pd

    frame = pd.DataFrame({
        "a": [0, 1],
        "b": [2, 3],
        "c": [4, 5]
    })
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, 'm.dmatrix')
        # Columns "a" and "b" are the features, column "c" is the label.
        original = xgb.DMatrix(frame.loc[:, ["a", "b"]], frame["c"])
        assert original.feature_names == ['a', 'b']
        original.save_binary(path)
        reloaded = xgb.DMatrix(path)
        assert original.feature_names == reloaded.feature_names
        assert original.feature_types == reloaded.feature_types
def test_get_info(self): def test_get_info(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train') dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtrain.get_float_info('label') dtrain.get_float_info('label')