Restore unknown data support. (#6595)

This commit is contained in:
Jiaming Yuan
2021-01-14 04:51:16 +08:00
committed by GitHub
parent 89a00a5866
commit d356b7a071
4 changed files with 36 additions and 22 deletions

View File

@@ -262,25 +262,6 @@ def c_array(ctype, values):
return (ctype * len(values))(*values)
def _convert_unknown_data(data, meta=None, meta_type=None):
    """Coerce an input of unrecognised type into something usable.

    Parameters
    ----------
    data :
        The object to convert.
    meta :
        When not ``None``, ``data`` is converted to a NumPy array instead of
        a sparse matrix.
    meta_type :
        ``dtype`` handed to ``np.array`` on the ``meta`` path.

    Returns
    -------
    numpy.ndarray or scipy.sparse.csr_matrix
        A NumPy array when ``meta`` is given, otherwise a CSR matrix.

    Raises
    ------
    TypeError
        If the conversion fails; the underlying error is chained as the
        cause.
    """
    # The failed conversion never rebinds ``data``, so the type name can be
    # captured once up front for both error messages.
    type_name = type(data).__name__

    if meta is None:
        message = ('Unknown data type: ' + str(type(data)) +
                   ', coverting it to csr_matrix')
        warnings.warn(message)
        try:
            return scipy.sparse.csr_matrix(data)
        except Exception as err:
            raise TypeError(
                'Can not initialize DMatrix from {}'.format(type_name)
            ) from err

    try:
        return np.array(data, dtype=meta_type)
    except Exception as err:
        raise TypeError(
            'Can not handle data from {}'.format(type_name)
        ) from err
class DataIter:
'''The interface for user defined data iterator. Currently is only
supported by Device DMatrix.
@@ -542,7 +523,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
if group is not None:
self.set_group(group)
if qid is not None:
dispatch_meta_backend(matrix=self, data=qid, name='qid')
self.set_uint_info('qid', qid)
if label_lower_bound is not None:
self.set_float_info('label_lower_bound', label_lower_bound)
if label_upper_bound is not None:

View File

@@ -517,6 +517,24 @@ def _has_array_protocol(data):
return hasattr(data, '__array__')
def _convert_unknown_data(data):
    """Try to convert an input of unrecognised type into a SciPy CSR matrix.

    A ``UserWarning`` naming the offending type is always emitted before the
    conversion is attempted.

    Parameters
    ----------
    data :
        The object to convert.

    Returns
    -------
    scipy.sparse.csr_matrix or None
        The converted matrix, or ``None`` when SciPy cannot be imported or
        the conversion raises.  Callers should test the result with
        ``is not None``: the truth value of a sparse matrix with more than
        one element is ambiguous and raises ``ValueError``.
    """
    warnings.warn(
        f'Unknown data type: {type(data)}, trying to convert it to csr_matrix',
        UserWarning
    )
    try:
        # Import the submodule explicitly: a bare ``import scipy`` does not
        # guarantee that ``scipy.sparse`` is loaded, in which case the
        # attribute lookup below would fail and be silently swallowed.
        import scipy.sparse
    except ImportError:
        return None

    try:
        return scipy.sparse.csr_matrix(data)
    except Exception:  # pylint: disable=broad-except
        # Best effort by design: any conversion failure means "unsupported".
        return None
def dispatch_data_backend(data, missing, threads,
feature_names, feature_types,
enable_categorical=False):
@@ -570,6 +588,11 @@ def dispatch_data_backend(data, missing, threads,
feature_types)
if _has_array_protocol(data):
pass
converted = _convert_unknown_data(data)
if converted:
return _from_scipy_csr(data, missing, feature_names, feature_types)
raise TypeError('Not supported type for data.' + str(type(data)))

View File

@@ -63,8 +63,9 @@ def get_host_ip(hostIP=None):
try:
hostIP = socket.gethostbyname(socket.getfqdn())
except gaierror:
logging.warning(
'gethostbyname(socket.getfqdn()) failed... trying on hostname()')
logging.debug(
'gethostbyname(socket.getfqdn()) failed... trying on hostname()'
)
hostIP = socket.gethostbyname(socket.gethostname())
if hostIP.startswith("127."):
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)