Restore unknown data support. (#6595)
This commit is contained in:
parent
89a00a5866
commit
d356b7a071
@ -262,25 +262,6 @@ def c_array(ctype, values):
|
|||||||
return (ctype * len(values))(*values)
|
return (ctype * len(values))(*values)
|
||||||
|
|
||||||
|
|
||||||
def _convert_unknown_data(data, meta=None, meta_type=None):
    """Coerce input of an unrecognised type into a supported container.

    Parameters
    ----------
    data :
        Object whose type is not handled by a dedicated code path.
    meta :
        Only compared against ``None``.  When it is not ``None`` the input is
        converted to a numpy array instead of a sparse matrix.
    meta_type :
        dtype forwarded to ``np.array`` when ``meta`` is not ``None``.

    Returns
    -------
    A numpy array when ``meta`` is not ``None``, otherwise a
    ``scipy.sparse.csr_matrix`` built from ``data``.

    Raises
    ------
    TypeError
        If the conversion fails; the underlying error is chained as the
        cause.
    """
    if meta is not None:
        try:
            return np.array(data, dtype=meta_type)
        except Exception as e:
            raise TypeError('Can not handle data from {}'.format(
                type(data).__name__)) from e

    # Let the user know a fallback conversion is happening.
    warnings.warn(
        'Unknown data type: ' + str(type(data)) +
        ', converting it to csr_matrix')
    try:
        return scipy.sparse.csr_matrix(data)
    except Exception as e:
        raise TypeError('Can not initialize DMatrix from'
                        ' {}'.format(type(data).__name__)) from e
|
|
||||||
|
|
||||||
|
|
||||||
class DataIter:
|
class DataIter:
|
||||||
'''The interface for user defined data iterator. Currently is only
|
'''The interface for user defined data iterator. Currently is only
|
||||||
supported by Device DMatrix.
|
supported by Device DMatrix.
|
||||||
@ -542,7 +523,7 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
|||||||
if group is not None:
|
if group is not None:
|
||||||
self.set_group(group)
|
self.set_group(group)
|
||||||
if qid is not None:
|
if qid is not None:
|
||||||
dispatch_meta_backend(matrix=self, data=qid, name='qid')
|
self.set_uint_info('qid', qid)
|
||||||
if label_lower_bound is not None:
|
if label_lower_bound is not None:
|
||||||
self.set_float_info('label_lower_bound', label_lower_bound)
|
self.set_float_info('label_lower_bound', label_lower_bound)
|
||||||
if label_upper_bound is not None:
|
if label_upper_bound is not None:
|
||||||
|
|||||||
@ -517,6 +517,24 @@ def _has_array_protocol(data):
|
|||||||
return hasattr(data, '__array__')
|
return hasattr(data, '__array__')
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_unknown_data(data):
    """Best-effort conversion of an unrecognised input into a CSR matrix.

    A ``UserWarning`` is always emitted to tell the caller that a fallback
    conversion is being attempted.

    Parameters
    ----------
    data :
        Object whose type is not handled by a dedicated code path.

    Returns
    -------
    The ``scipy.sparse.csr_matrix`` built from ``data``, or ``None`` when
    scipy is not installed or the conversion raises.
    """
    warnings.warn(
        f'Unknown data type: {type(data)}, trying to convert it to csr_matrix',
        UserWarning
    )
    try:
        # Import the submodule explicitly: a bare ``import scipy`` does not
        # guarantee that ``scipy.sparse`` is loaded, and the resulting
        # AttributeError would be hidden by the broad handler below.
        import scipy.sparse
    except ImportError:
        return None

    try:
        return scipy.sparse.csr_matrix(data)
    except Exception:  # pylint: disable=broad-except
        # Arbitrary user objects can fail in arbitrary ways; the caller
        # treats ``None`` as "could not convert".
        return None
|
||||||
|
|
||||||
|
|
||||||
def dispatch_data_backend(data, missing, threads,
|
def dispatch_data_backend(data, missing, threads,
|
||||||
feature_names, feature_types,
|
feature_names, feature_types,
|
||||||
enable_categorical=False):
|
enable_categorical=False):
|
||||||
@ -570,6 +588,11 @@ def dispatch_data_backend(data, missing, threads,
|
|||||||
feature_types)
|
feature_types)
|
||||||
if _has_array_protocol(data):
|
if _has_array_protocol(data):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
converted = _convert_unknown_data(data)
|
||||||
|
if converted:
|
||||||
|
return _from_scipy_csr(data, missing, feature_names, feature_types)
|
||||||
|
|
||||||
raise TypeError('Not supported type for data.' + str(type(data)))
|
raise TypeError('Not supported type for data.' + str(type(data)))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -63,8 +63,9 @@ def get_host_ip(hostIP=None):
|
|||||||
try:
|
try:
|
||||||
hostIP = socket.gethostbyname(socket.getfqdn())
|
hostIP = socket.gethostbyname(socket.getfqdn())
|
||||||
except gaierror:
|
except gaierror:
|
||||||
logging.warning(
|
logging.debug(
|
||||||
'gethostbyname(socket.getfqdn()) failed... trying on hostname()')
|
'gethostbyname(socket.getfqdn()) failed... trying on hostname()'
|
||||||
|
)
|
||||||
hostIP = socket.gethostbyname(socket.gethostname())
|
hostIP = socket.gethostbyname(socket.gethostname())
|
||||||
if hostIP.startswith("127."):
|
if hostIP.startswith("127."):
|
||||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
|||||||
@ -296,3 +296,12 @@ class TestDMatrix:
|
|||||||
param = {'max_depth': 3, 'objective': 'binary:logistic', 'verbosity': 0}
|
param = {'max_depth': 3, 'objective': 'binary:logistic', 'verbosity': 0}
|
||||||
bst = xgb.train(param, dtrain, 5, watchlist)
|
bst = xgb.train(param, dtrain, 5, watchlist)
|
||||||
bst.predict(dtrain)
|
bst.predict(dtrain)
|
||||||
|
|
||||||
|
def test_unknown_data(self):
    """Building a DMatrix from an instance of an empty class is expected to
    emit a UserWarning and raise TypeError."""
    class Data:
        pass

    unsupported = Data()
    # Same nesting as two stacked ``with`` statements: ``raises`` is the
    # outer context, ``warns`` the inner one.
    with pytest.raises(TypeError), pytest.warns(UserWarning):
        xgb.DMatrix(unsupported)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user