Restore unknown data support. (#6595)

This commit is contained in:
Jiaming Yuan 2021-01-14 04:51:16 +08:00 committed by GitHub
parent 89a00a5866
commit d356b7a071
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 36 additions and 22 deletions

View File

@ -262,25 +262,6 @@ def c_array(ctype, values):
return (ctype * len(values))(*values)
def _convert_unknown_data(data, meta=None, meta_type=None):
if meta is not None:
try:
data = np.array(data, dtype=meta_type)
except Exception as e:
raise TypeError('Can not handle data from {}'.format(
type(data).__name__)) from e
else:
warnings.warn(
'Unknown data type: ' + str(type(data)) +
', coverting it to csr_matrix')
try:
data = scipy.sparse.csr_matrix(data)
except Exception as e:
raise TypeError('Can not initialize DMatrix from'
' {}'.format(type(data).__name__)) from e
return data
class DataIter:
'''The interface for user defined data iterator. Currently is only
supported by Device DMatrix.
@ -542,7 +523,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
if group is not None:
self.set_group(group)
if qid is not None:
dispatch_meta_backend(matrix=self, data=qid, name='qid')
self.set_uint_info('qid', qid)
if label_lower_bound is not None:
self.set_float_info('label_lower_bound', label_lower_bound)
if label_upper_bound is not None:

View File

@ -517,6 +517,24 @@ def _has_array_protocol(data):
return hasattr(data, '__array__')
def _convert_unknown_data(data):
warnings.warn(
f'Unknown data type: {type(data)}, trying to convert it to csr_matrix',
UserWarning
)
try:
import scipy
except ImportError:
return None
try:
data = scipy.sparse.csr_matrix(data)
except Exception: # pylint: disable=broad-except
return None
return data
def dispatch_data_backend(data, missing, threads,
feature_names, feature_types,
enable_categorical=False):
@ -570,6 +588,11 @@ def dispatch_data_backend(data, missing, threads,
feature_types)
if _has_array_protocol(data):
pass
converted = _convert_unknown_data(data)
if converted:
return _from_scipy_csr(data, missing, feature_names, feature_types)
raise TypeError('Not supported type for data.' + str(type(data)))

View File

@ -63,8 +63,9 @@ def get_host_ip(hostIP=None):
try:
hostIP = socket.gethostbyname(socket.getfqdn())
except gaierror:
logging.warning(
'gethostbyname(socket.getfqdn()) failed... trying on hostname()')
logging.debug(
'gethostbyname(socket.getfqdn()) failed... trying on hostname()'
)
hostIP = socket.gethostbyname(socket.gethostname())
if hostIP.startswith("127."):
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

View File

@ -296,3 +296,12 @@ class TestDMatrix:
param = {'max_depth': 3, 'objective': 'binary:logistic', 'verbosity': 0}
bst = xgb.train(param, dtrain, 5, watchlist)
bst.predict(dtrain)
def test_unknown_data(self):
    '''An unsupported input type must warn and then raise ``TypeError``.'''
    class Data:
        pass

    # ``pytest.warns`` has to be the outer context.  When the TypeError
    # propagates through it, pytest releases before 8 skip the warning
    # check, so nesting it inside ``pytest.raises`` would leave the
    # UserWarning unverified.
    with pytest.warns(UserWarning):
        with pytest.raises(TypeError):
            xgb.DMatrix(Data())