Export Python Interface for external memory. (#7070)
* Add Python iterator interface. * Add tests. * Add demo. * Add documents. * Handle empty dataset.
This commit is contained in:
@@ -6,7 +6,7 @@ Contributors: https://github.com/dmlc/xgboost/blob/master/CONTRIBUTORS.md
|
||||
|
||||
import os
|
||||
|
||||
from .core import DMatrix, DeviceQuantileDMatrix, Booster
|
||||
from .core import DMatrix, DeviceQuantileDMatrix, Booster, DataIter
|
||||
from .training import train, cv
|
||||
from . import rabit # noqa
|
||||
from . import tracker # noqa
|
||||
@@ -25,7 +25,7 @@ VERSION_FILE = os.path.join(os.path.dirname(__file__), 'VERSION')
|
||||
with open(VERSION_FILE) as f:
|
||||
__version__ = f.read().strip()
|
||||
|
||||
__all__ = ['DMatrix', 'DeviceQuantileDMatrix', 'Booster',
|
||||
__all__ = ['DMatrix', 'DeviceQuantileDMatrix', 'Booster', 'DataIter',
|
||||
'train', 'cv',
|
||||
'RabitTracker',
|
||||
'XGBModel', 'XGBClassifier', 'XGBRegressor', 'XGBRanker',
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
import collections
|
||||
# pylint: disable=no-name-in-module,import-error
|
||||
from collections.abc import Mapping
|
||||
from typing import List, Optional, Any, Union, Dict
|
||||
from typing import List, Optional, Any, Union, Dict, TypeVar
|
||||
# pylint: enable=no-name-in-module,import-error
|
||||
from typing import Callable, Tuple
|
||||
import ctypes
|
||||
@@ -313,78 +313,130 @@ def _prediction_output(shape, dims, predts, is_cuda):
|
||||
return arr_predict
|
||||
|
||||
|
||||
class DataIter:
|
||||
'''The interface for user defined data iterator. Currently is only supported by Device
|
||||
DMatrix.
|
||||
class DataIter: # pylint: disable=too-many-instance-attributes
|
||||
"""The interface for user defined data iterator.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cache_prefix:
|
||||
Prefix to the cache files, only used in external memory. It can be either an URI
|
||||
or a file path.
|
||||
|
||||
"""
|
||||
_T = TypeVar("_T")
|
||||
|
||||
def __init__(self, cache_prefix: Optional[str] = None) -> None:
|
||||
self.cache_prefix = cache_prefix
|
||||
|
||||
'''
|
||||
def __init__(self):
|
||||
self._handle = _ProxyDMatrix()
|
||||
self.exception = None
|
||||
self.enable_categorical = False
|
||||
self._allow_host = False
|
||||
self._exception: Optional[Exception] = None
|
||||
self._enable_categorical = False
|
||||
self._allow_host = True
|
||||
# Stage data in Python until reset or next is called to avoid the data being freed.
|
||||
self._temporary_data = None
|
||||
|
||||
def _get_callbacks(
|
||||
self, allow_host: bool, enable_categorical: bool
|
||||
) -> Tuple[Callable, Callable]:
|
||||
assert hasattr(self, "cache_prefix"), "__init__ is not called."
|
||||
self._reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(
|
||||
self._reset_wrapper
|
||||
)
|
||||
self._next_callback = ctypes.CFUNCTYPE(
|
||||
ctypes.c_int,
|
||||
ctypes.c_void_p,
|
||||
)(self._next_wrapper)
|
||||
self._allow_host = allow_host
|
||||
self._enable_categorical = enable_categorical
|
||||
return self._reset_callback, self._next_callback
|
||||
|
||||
@property
|
||||
def proxy(self):
|
||||
'''Handler of DMatrix proxy.'''
|
||||
def proxy(self) -> "_ProxyDMatrix":
|
||||
"""Handle of DMatrix proxy."""
|
||||
return self._handle
|
||||
|
||||
def reset_wrapper(self, this): # pylint: disable=unused-argument
|
||||
'''A wrapper for user defined `reset` function.'''
|
||||
self.reset()
|
||||
def _handle_exception(self, fn: Callable, dft_ret: _T) -> _T:
|
||||
if self._exception is not None:
|
||||
return dft_ret
|
||||
|
||||
def next_wrapper(self, this): # pylint: disable=unused-argument
|
||||
'''A wrapper for user defined `next` function.
|
||||
try:
|
||||
return fn()
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
# Defer the exception in order to return 0 and stop the iteration.
|
||||
# Exception inside a ctype callback function has no effect except
|
||||
# for printing to stderr (doesn't stop the execution).
|
||||
tb = sys.exc_info()[2]
|
||||
# On dask, the worker is restarted and somehow the information is
|
||||
# lost.
|
||||
self._exception = e.with_traceback(tb)
|
||||
return dft_ret
|
||||
|
||||
def _reraise(self) -> None:
|
||||
self._temporary_data = None
|
||||
if self._exception is not None:
|
||||
# pylint 2.7.0 believes `self._exception` can be None even with `assert
|
||||
# isinstance`
|
||||
exc = self._exception
|
||||
self._exception = None
|
||||
raise exc # pylint: disable=raising-bad-type
|
||||
|
||||
def __del__(self) -> None:
|
||||
assert self._temporary_data is None, self._temporary_data
|
||||
assert self._exception is None
|
||||
|
||||
def _reset_wrapper(self, this: None) -> None: # pylint: disable=unused-argument
|
||||
"""A wrapper for user defined `reset` function."""
|
||||
# free the data
|
||||
self._temporary_data = None
|
||||
self._handle_exception(self.reset, None)
|
||||
|
||||
def _next_wrapper(self, this: None) -> int: # pylint: disable=unused-argument
|
||||
"""A wrapper for user defined `next` function.
|
||||
|
||||
`this` is not used in Python. ctypes can handle `self` of a Python
|
||||
member function automatically when converting it to c function
|
||||
pointer.
|
||||
|
||||
'''
|
||||
if self.exception is not None:
|
||||
return 0
|
||||
|
||||
"""
|
||||
@_deprecate_positional_args
|
||||
def data_handle(
|
||||
data,
|
||||
feature_names=None,
|
||||
feature_types=None,
|
||||
**kwargs
|
||||
data: Any,
|
||||
*,
|
||||
feature_names: Optional[List[str]] = None,
|
||||
feature_types: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
from .data import dispatch_proxy_set_data
|
||||
from .data import _proxy_transform
|
||||
data, feature_names, feature_types = _proxy_transform(
|
||||
data, feature_names, feature_types, self.enable_categorical,
|
||||
|
||||
transformed, feature_names, feature_types = _proxy_transform(
|
||||
data,
|
||||
feature_names,
|
||||
feature_types,
|
||||
self._enable_categorical,
|
||||
)
|
||||
dispatch_proxy_set_data(self.proxy, data, self._allow_host)
|
||||
# Stage the data, meta info are copied inside C++ MetaInfo.
|
||||
self._temporary_data = transformed
|
||||
dispatch_proxy_set_data(self.proxy, transformed, self._allow_host)
|
||||
self.proxy.set_info(
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
**kwargs,
|
||||
)
|
||||
try:
|
||||
# Differ the exception in order to return 0 and stop the iteration.
|
||||
# Exception inside a ctype callback function has no effect except
|
||||
# for printing to stderr (doesn't stop the execution).
|
||||
ret = self.next(data_handle) # pylint: disable=not-callable
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
tb = sys.exc_info()[2]
|
||||
# On dask the worker is restarted and somehow the information is
|
||||
# lost.
|
||||
self.exception = e.with_traceback(tb)
|
||||
return 0
|
||||
return ret
|
||||
# pylint: disable=not-callable
|
||||
return self._handle_exception(lambda: self.next(data_handle), 0)
|
||||
|
||||
def reset(self):
|
||||
'''Reset the data iterator. Prototype for user defined function.'''
|
||||
def reset(self) -> None:
|
||||
"""Reset the data iterator. Prototype for user defined function."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def next(self, input_data):
|
||||
'''Set the next batch of data.
|
||||
def next(self, input_data: Callable) -> int:
|
||||
"""Set the next batch of data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
data_handle: callable
|
||||
data_handle:
|
||||
A function with same data fields like `data`, `label` with
|
||||
`xgboost.DMatrix`.
|
||||
|
||||
@@ -392,7 +444,7 @@ class DataIter:
|
||||
-------
|
||||
0 if there's no more batch, otherwise 1.
|
||||
|
||||
'''
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -546,7 +598,12 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
self.handle = None
|
||||
return
|
||||
|
||||
from .data import dispatch_data_backend
|
||||
from .data import dispatch_data_backend, _is_iter
|
||||
|
||||
if _is_iter(data):
|
||||
self._init_from_iter(data, enable_categorical)
|
||||
assert self.handle is not None
|
||||
return
|
||||
|
||||
handle, feature_names, feature_types = dispatch_data_backend(
|
||||
data,
|
||||
@@ -575,6 +632,33 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
if feature_types is not None:
|
||||
self.feature_types = feature_types
|
||||
|
||||
def _init_from_iter(self, iterator: DataIter, enable_categorical: bool):
    """Build this DMatrix by pulling batches from a user defined iterator.

    Parameters
    ----------
    iterator:
        Data iterator that supplies the batches.  Its ``cache_prefix`` is
        passed along in the JSON configuration (an empty string when unset).
    enable_categorical:
        Forwarded to the iterator when its callbacks are created.

    """
    config = from_pystr_to_cstr(
        json.dumps(
            {
                "missing": self.missing,
                "nthread": self.nthread,
                "cache_prefix": iterator.cache_prefix or "",
            }
        )
    )
    new_handle = ctypes.c_void_p()
    # pylint: disable=protected-access
    # The first argument is ``allow_host``: host data is accepted on this path.
    on_reset, on_next = iterator._get_callbacks(True, enable_categorical)
    status = _LIB.XGDMatrixCreateFromCallback(
        None,
        iterator.proxy.handle,
        on_reset,
        on_next,
        config,
        ctypes.byref(new_handle),
    )
    # Raise any exception captured inside the callbacks first; the return
    # code is only checked afterwards.
    iterator._reraise()
    _check_call(status)
    self.handle = new_handle
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "handle") and self.handle:
|
||||
_check_call(_LIB.XGDMatrixFree(self.handle))
|
||||
@@ -907,7 +991,8 @@ class DMatrix: # pylint: disable=too-many-instance-attributes
|
||||
if len(feature_names) != len(set(feature_names)):
|
||||
raise ValueError('feature_names must be unique')
|
||||
if len(feature_names) != self.num_col() and self.num_col() != 0:
|
||||
msg = 'feature_names must have the same length as data'
|
||||
msg = ("feature_names must have the same length as data, ",
|
||||
f"expected {self.num_col()}, got {len(feature_names)}")
|
||||
raise ValueError(msg)
|
||||
# prohibit to use symbols may affect to parse. e.g. []<
|
||||
if not all(isinstance(f, str) and
|
||||
@@ -1001,30 +1086,44 @@ class _ProxyDMatrix(DMatrix):
|
||||
inplace_predict).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self):  # pylint: disable=super-init-not-called
    """Create an empty proxy handle without running ``DMatrix.__init__``."""
    # Assign the attribute before the library call so that ``self.handle``
    # exists even if ``_check_call`` raises.
    self.handle = ctypes.c_void_p()
    handle_ref = ctypes.byref(self.handle)
    _check_call(_LIB.XGProxyDMatrixCreate(handle_ref))
|
||||
|
||||
def _set_data_from_cuda_interface(self, data):
|
||||
'''Set data from CUDA array interface.'''
|
||||
"""Set data from CUDA array interface."""
|
||||
interface = data.__cuda_array_interface__
|
||||
interface_str = bytes(json.dumps(interface, indent=2), 'utf-8')
|
||||
interface_str = bytes(json.dumps(interface, indent=2), "utf-8")
|
||||
_check_call(
|
||||
_LIB.XGProxyDMatrixSetDataCudaArrayInterface(
|
||||
self.handle,
|
||||
interface_str
|
||||
)
|
||||
_LIB.XGProxyDMatrixSetDataCudaArrayInterface(self.handle, interface_str)
|
||||
)
|
||||
|
||||
def _set_data_from_cuda_columnar(self, data):
|
||||
'''Set data from CUDA columnar format.'''
|
||||
"""Set data from CUDA columnar format."""
|
||||
from .data import _cudf_array_interfaces
|
||||
|
||||
_, interfaces_str = _cudf_array_interfaces(data)
|
||||
_check_call(_LIB.XGProxyDMatrixSetDataCudaColumnar(self.handle, interfaces_str))
|
||||
|
||||
def _set_data_from_array(self, data: np.ndarray):
|
||||
"""Set data from numpy array."""
|
||||
from .data import _array_interface
|
||||
|
||||
_check_call(
|
||||
_LIB.XGProxyDMatrixSetDataCudaColumnar(
|
||||
self.handle,
|
||||
interfaces_str
|
||||
)
|
||||
_LIB.XGProxyDMatrixSetDataDense(self.handle, _array_interface(data))
|
||||
)
|
||||
|
||||
def _set_data_from_csr(self, csr):
    """Set data from a scipy CSR matrix.

    Parameters
    ----------
    csr:
        Object exposing ``indptr``, ``indices``, ``data`` and ``shape``; the
        number of columns is read from ``shape[1]``.

    """
    from .data import _array_interface

    # Route the return code through ``_check_call`` like the other
    # ``_set_data_from_*`` methods, so a failure in the library call is
    # raised instead of being silently dropped.
    _check_call(
        _LIB.XGProxyDMatrixSetDataCSR(
            self.handle,
            _array_interface(csr.indptr),
            _array_interface(csr.indices),
            _array_interface(csr.data),
            ctypes.c_size_t(csr.shape[1]),
        )
    )
|
||||
|
||||
|
||||
@@ -1110,13 +1209,14 @@ class DeviceQuantileDMatrix(DMatrix):
|
||||
else:
|
||||
it = SingleBatchInternalIter(data=data, **meta)
|
||||
|
||||
it.enable_categorical = enable_categorical
|
||||
reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(it.reset_wrapper)
|
||||
next_callback = ctypes.CFUNCTYPE(
|
||||
ctypes.c_int,
|
||||
ctypes.c_void_p,
|
||||
)(it.next_wrapper)
|
||||
handle = ctypes.c_void_p()
|
||||
# pylint: disable=protected-access
|
||||
reset_callback, next_callback = it._get_callbacks(False, enable_categorical)
|
||||
if it.cache_prefix is not None:
|
||||
raise ValueError(
|
||||
"DeviceQuantileDMatrix doesn't cache data, remove the cache_prefix "
|
||||
"in iterator to fix this error."
|
||||
)
|
||||
ret = _LIB.XGDeviceQuantileDMatrixCreateFromCallback(
|
||||
None,
|
||||
it.proxy.handle,
|
||||
@@ -1127,10 +1227,8 @@ class DeviceQuantileDMatrix(DMatrix):
|
||||
ctypes.c_int(self.max_bin),
|
||||
ctypes.byref(handle),
|
||||
)
|
||||
if it.exception is not None:
|
||||
# pylint 2.7.0 believes `it.exception` can be None even with `assert
|
||||
# isinstace`
|
||||
raise it.exception # pylint: disable=raising-bad-type
|
||||
# pylint: disable=protected-access
|
||||
it._reraise()
|
||||
# delay check_call to throw intermediate exception first
|
||||
_check_call(ret)
|
||||
self.handle = handle
|
||||
@@ -2241,8 +2339,8 @@ class Booster(object):
|
||||
# pylint: disable=too-many-locals
|
||||
fmap = os.fspath(os.path.expanduser(fmap))
|
||||
if not PANDAS_INSTALLED:
|
||||
raise Exception(('pandas must be available to use this method.'
|
||||
'Install pandas before calling again.'))
|
||||
raise ImportError(('pandas must be available to use this method.'
|
||||
'Install pandas before calling again.'))
|
||||
|
||||
if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}:
|
||||
raise ValueError('This method is not defined for Booster type {}'
|
||||
|
||||
@@ -5,7 +5,7 @@ import ctypes
|
||||
import json
|
||||
import warnings
|
||||
import os
|
||||
from typing import Any, Tuple
|
||||
from typing import Any, Tuple, Callable
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -238,10 +238,13 @@ def _transform_pandas_df(data, enable_categorical,
|
||||
if meta and len(data.columns) > 1:
|
||||
raise ValueError(
|
||||
'DataFrame for {meta} cannot have multiple columns'.format(
|
||||
meta=meta))
|
||||
meta=meta)
|
||||
)
|
||||
|
||||
dtype = meta_type if meta_type else np.float32
|
||||
data = np.ascontiguousarray(data.values, dtype=dtype)
|
||||
data = data.values
|
||||
if meta_type:
|
||||
data = data.astype(meta_type)
|
||||
return data, feature_names, feature_types
|
||||
|
||||
|
||||
@@ -759,19 +762,19 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
|
||||
area for meta info.
|
||||
|
||||
'''
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs: Any):
|
||||
self.kwargs = kwargs
|
||||
self.it = 0 # pylint: disable=invalid-name
|
||||
super().__init__()
|
||||
|
||||
def next(self, input_data):
|
||||
def next(self, input_data: Callable) -> int:
|
||||
if self.it == 1:
|
||||
return 0
|
||||
self.it += 1
|
||||
input_data(**self.kwargs)
|
||||
return 1
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
self.it = 0
|
||||
|
||||
|
||||
@@ -785,6 +788,15 @@ def _proxy_transform(data, feature_names, feature_types, enable_categorical):
|
||||
return data, feature_names, feature_types
|
||||
if _is_dlpack(data):
|
||||
return _transform_dlpack(data), feature_names, feature_types
|
||||
if _is_numpy_array(data):
|
||||
return data, feature_names, feature_types
|
||||
if _is_scipy_csr(data):
|
||||
return data, feature_names, feature_types
|
||||
if _is_pandas_df(data):
|
||||
arr, feature_names, feature_types = _transform_pandas_df(
|
||||
data, enable_categorical, feature_names, feature_types
|
||||
)
|
||||
return arr, feature_names, feature_types
|
||||
raise TypeError("Value type is not supported for data iterator:" + str(type(data)))
|
||||
|
||||
|
||||
@@ -803,7 +815,16 @@ def dispatch_proxy_set_data(proxy: _ProxyDMatrix, data: Any, allow_host: bool) -
|
||||
data = _transform_dlpack(data)
|
||||
proxy._set_data_from_cuda_interface(data) # pylint: disable=W0212
|
||||
return
|
||||
# Part of https://github.com/dmlc/xgboost/pull/7070
|
||||
assert allow_host is False, "host data is not yet supported."
|
||||
raise TypeError('Value type is not supported for data iterator:' +
|
||||
str(type(data)))
|
||||
|
||||
err = TypeError("Value type is not supported for data iterator:" + str(type(data)))
|
||||
|
||||
if not allow_host:
|
||||
raise err
|
||||
|
||||
if _is_numpy_array(data):
|
||||
proxy._set_data_from_array(data) # pylint: disable=W0212
|
||||
return
|
||||
if _is_scipy_csr(data):
|
||||
proxy._set_data_from_csr(data) # pylint: disable=W0212
|
||||
return
|
||||
raise err
|
||||
|
||||
Reference in New Issue
Block a user