Restore unknown data support. (#6595)
This commit is contained in:
parent
89a00a5866
commit
d356b7a071
@ -262,25 +262,6 @@ def c_array(ctype, values):
|
||||
return (ctype * len(values))(*values)
|
||||
|
||||
|
||||
def _convert_unknown_data(data, meta=None, meta_type=None):
|
||||
if meta is not None:
|
||||
try:
|
||||
data = np.array(data, dtype=meta_type)
|
||||
except Exception as e:
|
||||
raise TypeError('Can not handle data from {}'.format(
|
||||
type(data).__name__)) from e
|
||||
else:
|
||||
warnings.warn(
|
||||
'Unknown data type: ' + str(type(data)) +
|
||||
', coverting it to csr_matrix')
|
||||
try:
|
||||
data = scipy.sparse.csr_matrix(data)
|
||||
except Exception as e:
|
||||
raise TypeError('Can not initialize DMatrix from'
|
||||
' {}'.format(type(data).__name__)) from e
|
||||
return data
|
||||
|
||||
|
||||
class DataIter:
|
||||
'''The interface for user defined data iterator. Currently is only
|
||||
supported by Device DMatrix.
|
||||
@ -542,7 +523,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
if group is not None:
|
||||
self.set_group(group)
|
||||
if qid is not None:
|
||||
dispatch_meta_backend(matrix=self, data=qid, name='qid')
|
||||
self.set_uint_info('qid', qid)
|
||||
if label_lower_bound is not None:
|
||||
self.set_float_info('label_lower_bound', label_lower_bound)
|
||||
if label_upper_bound is not None:
|
||||
|
||||
@ -517,6 +517,24 @@ def _has_array_protocol(data):
|
||||
return hasattr(data, '__array__')
|
||||
|
||||
|
||||
def _convert_unknown_data(data):
|
||||
warnings.warn(
|
||||
f'Unknown data type: {type(data)}, trying to convert it to csr_matrix',
|
||||
UserWarning
|
||||
)
|
||||
try:
|
||||
import scipy
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
try:
|
||||
data = scipy.sparse.csr_matrix(data)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
return None
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def dispatch_data_backend(data, missing, threads,
|
||||
feature_names, feature_types,
|
||||
enable_categorical=False):
|
||||
@ -570,6 +588,11 @@ def dispatch_data_backend(data, missing, threads,
|
||||
feature_types)
|
||||
if _has_array_protocol(data):
|
||||
pass
|
||||
|
||||
converted = _convert_unknown_data(data)
|
||||
if converted:
|
||||
return _from_scipy_csr(data, missing, feature_names, feature_types)
|
||||
|
||||
raise TypeError('Not supported type for data.' + str(type(data)))
|
||||
|
||||
|
||||
|
||||
@ -63,8 +63,9 @@ def get_host_ip(hostIP=None):
|
||||
try:
|
||||
hostIP = socket.gethostbyname(socket.getfqdn())
|
||||
except gaierror:
|
||||
logging.warning(
|
||||
'gethostbyname(socket.getfqdn()) failed... trying on hostname()')
|
||||
logging.debug(
|
||||
'gethostbyname(socket.getfqdn()) failed... trying on hostname()'
|
||||
)
|
||||
hostIP = socket.gethostbyname(socket.gethostname())
|
||||
if hostIP.startswith("127."):
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
|
||||
@ -296,3 +296,12 @@ class TestDMatrix:
|
||||
param = {'max_depth': 3, 'objective': 'binary:logistic', 'verbosity': 0}
|
||||
bst = xgb.train(param, dtrain, 5, watchlist)
|
||||
bst.predict(dtrain)
|
||||
|
||||
def test_unknown_data(self):
    """An object of an unsupported type must warn and then be rejected."""
    class Data:
        pass

    unknown = Data()
    # Both expectations apply to the same call: a UserWarning is issued
    # and a TypeError propagates out of the constructor.
    with pytest.raises(TypeError), pytest.warns(UserWarning):
        xgb.DMatrix(unknown)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user