Categorical data support for cuDF. (#7042)
* Add support in DMatrix. * Add support in DQM, except for iterator.
This commit is contained in:
parent
5c2d7a18c9
commit
d9799b09d0
@ -231,17 +231,6 @@ def _numpy2ctypes_type(dtype):
|
|||||||
return _NUMPY_TO_CTYPES_MAPPING[dtype]
|
return _NUMPY_TO_CTYPES_MAPPING[dtype]
|
||||||
|
|
||||||
|
|
||||||
def _array_interface(data: np.ndarray) -> bytes:
|
|
||||||
assert (
|
|
||||||
data.dtype.hasobject is False
|
|
||||||
), "Input data contains `object` dtype. Expecting numeric data."
|
|
||||||
interface = data.__array_interface__
|
|
||||||
if "mask" in interface:
|
|
||||||
interface["mask"] = interface["mask"].__array_interface__
|
|
||||||
interface_str = bytes(json.dumps(interface), "utf-8")
|
|
||||||
return interface_str
|
|
||||||
|
|
||||||
|
|
||||||
def _cuda_array_interface(data) -> bytes:
|
def _cuda_array_interface(data) -> bytes:
|
||||||
assert (
|
assert (
|
||||||
data.dtype.hasobject is False
|
data.dtype.hasobject is False
|
||||||
@ -353,11 +342,17 @@ class DataIter:
|
|||||||
if self.exception is not None:
|
if self.exception is not None:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def data_handle(data, feature_names=None, feature_types=None, **kwargs):
|
def data_handle(
|
||||||
|
data,
|
||||||
|
feature_names=None,
|
||||||
|
feature_types=None,
|
||||||
|
enable_categorical=False,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
from .data import dispatch_device_quantile_dmatrix_set_data
|
from .data import dispatch_device_quantile_dmatrix_set_data
|
||||||
from .data import _device_quantile_transform
|
from .data import _device_quantile_transform
|
||||||
data, feature_names, feature_types = _device_quantile_transform(
|
data, feature_names, feature_types = _device_quantile_transform(
|
||||||
data, feature_names, feature_types
|
data, feature_names, feature_types, enable_categorical,
|
||||||
)
|
)
|
||||||
dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
|
dispatch_device_quantile_dmatrix_set_data(self.proxy, data)
|
||||||
self.proxy.set_info(
|
self.proxy.set_info(
|
||||||
@ -1023,7 +1018,7 @@ class _ProxyDMatrix(DMatrix):
|
|||||||
def _set_data_from_cuda_columnar(self, data):
|
def _set_data_from_cuda_columnar(self, data):
|
||||||
'''Set data from CUDA columnar format.1'''
|
'''Set data from CUDA columnar format.1'''
|
||||||
from .data import _cudf_array_interfaces
|
from .data import _cudf_array_interfaces
|
||||||
interfaces_str = _cudf_array_interfaces(data)
|
_, interfaces_str = _cudf_array_interfaces(data)
|
||||||
_check_call(
|
_check_call(
|
||||||
_LIB.XGDeviceQuantileDMatrixSetDataCudaColumnar(
|
_LIB.XGDeviceQuantileDMatrixSetDataCudaColumnar(
|
||||||
self.handle,
|
self.handle,
|
||||||
@ -1076,10 +1071,6 @@ class DeviceQuantileDMatrix(DMatrix):
|
|||||||
self.handle = data
|
self.handle = data
|
||||||
return
|
return
|
||||||
|
|
||||||
if enable_categorical:
|
|
||||||
raise NotImplementedError(
|
|
||||||
'categorical support is not enabled on DeviceQuantileDMatrix.'
|
|
||||||
)
|
|
||||||
if qid is not None and group is not None:
|
if qid is not None and group is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Only one of the eval_qid or eval_group for each evaluation '
|
'Only one of the eval_qid or eval_group for each evaluation '
|
||||||
@ -1098,9 +1089,10 @@ class DeviceQuantileDMatrix(DMatrix):
|
|||||||
feature_weights=feature_weights,
|
feature_weights=feature_weights,
|
||||||
feature_names=feature_names,
|
feature_names=feature_names,
|
||||||
feature_types=feature_types,
|
feature_types=feature_types,
|
||||||
|
enable_categorical=enable_categorical,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _init(self, data, feature_names, feature_types, **meta):
|
def _init(self, data, enable_categorical, **meta):
|
||||||
from .data import (
|
from .data import (
|
||||||
_is_dlpack,
|
_is_dlpack,
|
||||||
_transform_dlpack,
|
_transform_dlpack,
|
||||||
@ -1114,9 +1106,13 @@ class DeviceQuantileDMatrix(DMatrix):
|
|||||||
data = _transform_dlpack(data)
|
data = _transform_dlpack(data)
|
||||||
if _is_iter(data):
|
if _is_iter(data):
|
||||||
it = data
|
it = data
|
||||||
|
if enable_categorical:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"categorical support is not enabled on data iterator."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
it = SingleBatchInternalIter(
|
it = SingleBatchInternalIter(
|
||||||
data, **meta, feature_names=feature_names, feature_types=feature_types
|
data=data, enable_categorical=enable_categorical, **meta
|
||||||
)
|
)
|
||||||
|
|
||||||
reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(it.reset_wrapper)
|
reset_callback = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(it.reset_wrapper)
|
||||||
@ -1920,6 +1916,7 @@ class Booster(object):
|
|||||||
f"got {data.shape[1]}"
|
f"got {data.shape[1]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .data import _array_interface
|
||||||
if isinstance(data, np.ndarray):
|
if isinstance(data, np.ndarray):
|
||||||
from .data import _ensure_np_dtype
|
from .data import _ensure_np_dtype
|
||||||
data, _ = _ensure_np_dtype(data, data.dtype)
|
data, _ = _ensure_np_dtype(data, data.dtype)
|
||||||
@ -1974,7 +1971,7 @@ class Booster(object):
|
|||||||
if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
|
if lazy_isinstance(data, "cudf.core.dataframe", "DataFrame"):
|
||||||
from .data import _cudf_array_interfaces
|
from .data import _cudf_array_interfaces
|
||||||
|
|
||||||
interfaces_str = _cudf_array_interfaces(data)
|
_, interfaces_str = _cudf_array_interfaces(data)
|
||||||
_check_call(
|
_check_call(
|
||||||
_LIB.XGBoosterPredictFromCudaColumnar(
|
_LIB.XGBoosterPredictFromCudaColumnar(
|
||||||
self.handle,
|
self.handle,
|
||||||
|
|||||||
@ -5,12 +5,12 @@ import ctypes
|
|||||||
import json
|
import json
|
||||||
import warnings
|
import warnings
|
||||||
import os
|
import os
|
||||||
from typing import Any
|
from typing import Any, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .core import c_array, _LIB, _check_call, c_str
|
from .core import c_array, _LIB, _check_call, c_str
|
||||||
from .core import _array_interface, _cuda_array_interface
|
from .core import _cuda_array_interface
|
||||||
from .core import DataIter, _ProxyDMatrix, DMatrix
|
from .core import DataIter, _ProxyDMatrix, DMatrix
|
||||||
from .compat import lazy_isinstance
|
from .compat import lazy_isinstance
|
||||||
|
|
||||||
@ -41,6 +41,17 @@ def _is_scipy_csr(data):
|
|||||||
return isinstance(data, scipy.sparse.csr_matrix)
|
return isinstance(data, scipy.sparse.csr_matrix)
|
||||||
|
|
||||||
|
|
||||||
|
def _array_interface(data: np.ndarray) -> bytes:
|
||||||
|
assert (
|
||||||
|
data.dtype.hasobject is False
|
||||||
|
), "Input data contains `object` dtype. Expecting numeric data."
|
||||||
|
interface = data.__array_interface__
|
||||||
|
if "mask" in interface:
|
||||||
|
interface["mask"] = interface["mask"].__array_interface__
|
||||||
|
interface_str = bytes(json.dumps(interface), "utf-8")
|
||||||
|
return interface_str
|
||||||
|
|
||||||
|
|
||||||
def _from_scipy_csr(data, missing, nthread, feature_names, feature_types):
|
def _from_scipy_csr(data, missing, nthread, feature_names, feature_types):
|
||||||
"""Initialize data from a CSR matrix."""
|
"""Initialize data from a CSR matrix."""
|
||||||
if len(data.indices) != len(data.data):
|
if len(data.indices) != len(data.data):
|
||||||
@ -179,7 +190,7 @@ _pandas_dtype_mapper = {
|
|||||||
'float16': 'float',
|
'float16': 'float',
|
||||||
'float32': 'float',
|
'float32': 'float',
|
||||||
'float64': 'float',
|
'float64': 'float',
|
||||||
'bool': 'i'
|
'bool': 'i',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -349,54 +360,73 @@ def _is_cudf_df(data):
|
|||||||
return hasattr(cudf, 'DataFrame') and isinstance(data, cudf.DataFrame)
|
return hasattr(cudf, 'DataFrame') and isinstance(data, cudf.DataFrame)
|
||||||
|
|
||||||
|
|
||||||
def _cudf_array_interfaces(data):
|
def _cudf_array_interfaces(data) -> Tuple[list, list]:
|
||||||
'''Extract CuDF __cuda_array_interface__'''
|
"""Extract CuDF __cuda_array_interface__. This is special as it returns a new list of
|
||||||
|
data and a list of array interfaces. The data is list of categorical codes that
|
||||||
|
caller can safely ignore, but have to keep their reference alive until usage of array
|
||||||
|
interface is finished.
|
||||||
|
|
||||||
|
"""
|
||||||
|
from cudf.utils.dtypes import is_categorical_dtype
|
||||||
|
cat_codes = []
|
||||||
interfaces = []
|
interfaces = []
|
||||||
if _is_cudf_ser(data):
|
if _is_cudf_ser(data):
|
||||||
interfaces.append(data.__cuda_array_interface__)
|
interfaces.append(data.__cuda_array_interface__)
|
||||||
else:
|
else:
|
||||||
for col in data:
|
for col in data:
|
||||||
|
if is_categorical_dtype(data[col].dtype):
|
||||||
|
codes = data[col].cat.codes
|
||||||
|
interface = codes.__cuda_array_interface__
|
||||||
|
cat_codes.append(codes)
|
||||||
|
else:
|
||||||
interface = data[col].__cuda_array_interface__
|
interface = data[col].__cuda_array_interface__
|
||||||
if 'mask' in interface:
|
if "mask" in interface:
|
||||||
interface['mask'] = interface['mask'].__cuda_array_interface__
|
interface["mask"] = interface["mask"].__cuda_array_interface__
|
||||||
interfaces.append(interface)
|
interfaces.append(interface)
|
||||||
interfaces_str = bytes(json.dumps(interfaces, indent=2), 'utf-8')
|
interfaces_str = bytes(json.dumps(interfaces, indent=2), "utf-8")
|
||||||
return interfaces_str
|
return cat_codes, interfaces_str
|
||||||
|
|
||||||
|
|
||||||
def _transform_cudf_df(data, feature_names, feature_types):
|
def _transform_cudf_df(data, feature_names, feature_types, enable_categorical):
|
||||||
|
from cudf.utils.dtypes import is_categorical_dtype
|
||||||
|
|
||||||
if feature_names is None:
|
if feature_names is None:
|
||||||
if _is_cudf_ser(data):
|
if _is_cudf_ser(data):
|
||||||
feature_names = [data.name]
|
feature_names = [data.name]
|
||||||
elif lazy_isinstance(
|
elif lazy_isinstance(data.columns, "cudf.core.multiindex", "MultiIndex"):
|
||||||
data.columns, 'cudf.core.multiindex', 'MultiIndex'):
|
feature_names = [" ".join([str(x) for x in i]) for i in data.columns]
|
||||||
feature_names = [
|
|
||||||
' '.join([str(x) for x in i])
|
|
||||||
for i in data.columns
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
feature_names = data.columns.format()
|
feature_names = data.columns.format()
|
||||||
if feature_types is None:
|
if feature_types is None:
|
||||||
|
feature_types = []
|
||||||
if _is_cudf_ser(data):
|
if _is_cudf_ser(data):
|
||||||
dtypes = [data.dtype]
|
dtypes = [data.dtype]
|
||||||
else:
|
else:
|
||||||
dtypes = data.dtypes
|
dtypes = data.dtypes
|
||||||
feature_types = [_pandas_dtype_mapper[d.name]
|
for dtype in dtypes:
|
||||||
for d in dtypes]
|
if is_categorical_dtype(dtype) and enable_categorical:
|
||||||
|
feature_types.append("categorical")
|
||||||
|
else:
|
||||||
|
feature_types.append(_pandas_dtype_mapper[dtype.name])
|
||||||
return data, feature_names, feature_types
|
return data, feature_names, feature_types
|
||||||
|
|
||||||
|
|
||||||
def _from_cudf_df(data, missing, nthread, feature_names, feature_types):
|
def _from_cudf_df(
|
||||||
|
data, missing, nthread, feature_names, feature_types, enable_categorical
|
||||||
|
):
|
||||||
data, feature_names, feature_types = _transform_cudf_df(
|
data, feature_names, feature_types = _transform_cudf_df(
|
||||||
data, feature_names, feature_types)
|
data, feature_names, feature_types, enable_categorical
|
||||||
interfaces_str = _cudf_array_interfaces(data)
|
)
|
||||||
|
_, interfaces_str = _cudf_array_interfaces(data)
|
||||||
handle = ctypes.c_void_p()
|
handle = ctypes.c_void_p()
|
||||||
_check_call(
|
_check_call(
|
||||||
_LIB.XGDMatrixCreateFromArrayInterfaceColumns(
|
_LIB.XGDMatrixCreateFromArrayInterfaceColumns(
|
||||||
interfaces_str,
|
interfaces_str,
|
||||||
ctypes.c_float(missing),
|
ctypes.c_float(missing),
|
||||||
ctypes.c_int(nthread),
|
ctypes.c_int(nthread),
|
||||||
ctypes.byref(handle)))
|
ctypes.byref(handle),
|
||||||
|
)
|
||||||
|
)
|
||||||
return handle, feature_names, feature_types
|
return handle, feature_names, feature_types
|
||||||
|
|
||||||
|
|
||||||
@ -554,12 +584,10 @@ def dispatch_data_backend(data, missing, threads,
|
|||||||
if _is_pandas_series(data):
|
if _is_pandas_series(data):
|
||||||
return _from_pandas_series(data, missing, threads, feature_names,
|
return _from_pandas_series(data, missing, threads, feature_names,
|
||||||
feature_types)
|
feature_types)
|
||||||
if _is_cudf_df(data):
|
if _is_cudf_df(data) or _is_cudf_ser(data):
|
||||||
return _from_cudf_df(data, missing, threads, feature_names,
|
return _from_cudf_df(
|
||||||
feature_types)
|
data, missing, threads, feature_names, feature_types, enable_categorical
|
||||||
if _is_cudf_ser(data):
|
)
|
||||||
return _from_cudf_df(data, missing, threads, feature_names,
|
|
||||||
feature_types)
|
|
||||||
if _is_cupy_array(data):
|
if _is_cupy_array(data):
|
||||||
return _from_cupy_array(data, missing, threads, feature_names,
|
return _from_cupy_array(data, missing, threads, feature_names,
|
||||||
feature_types)
|
feature_types)
|
||||||
@ -731,30 +759,8 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
|
|||||||
area for meta info.
|
area for meta info.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
def __init__(
|
def __init__(self, **kwargs):
|
||||||
self, data,
|
self.kwargs = kwargs
|
||||||
label,
|
|
||||||
weight,
|
|
||||||
base_margin,
|
|
||||||
group,
|
|
||||||
qid,
|
|
||||||
label_lower_bound,
|
|
||||||
label_upper_bound,
|
|
||||||
feature_weights,
|
|
||||||
feature_names,
|
|
||||||
feature_types
|
|
||||||
):
|
|
||||||
self.data = data
|
|
||||||
self.label = label
|
|
||||||
self.weight = weight
|
|
||||||
self.base_margin = base_margin
|
|
||||||
self.group = group
|
|
||||||
self.qid = qid
|
|
||||||
self.label_lower_bound = label_lower_bound
|
|
||||||
self.label_upper_bound = label_upper_bound
|
|
||||||
self.feature_weights = feature_weights
|
|
||||||
self.feature_names = feature_names
|
|
||||||
self.feature_types = feature_types
|
|
||||||
self.it = 0 # pylint: disable=invalid-name
|
self.it = 0 # pylint: disable=invalid-name
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -762,33 +768,24 @@ class SingleBatchInternalIter(DataIter): # pylint: disable=R0902
|
|||||||
if self.it == 1:
|
if self.it == 1:
|
||||||
return 0
|
return 0
|
||||||
self.it += 1
|
self.it += 1
|
||||||
input_data(data=self.data, label=self.label,
|
input_data(**self.kwargs)
|
||||||
weight=self.weight, base_margin=self.base_margin,
|
|
||||||
group=self.group,
|
|
||||||
qid=self.qid,
|
|
||||||
label_lower_bound=self.label_lower_bound,
|
|
||||||
label_upper_bound=self.label_upper_bound,
|
|
||||||
feature_weights=self.feature_weights,
|
|
||||||
feature_names=self.feature_names,
|
|
||||||
feature_types=self.feature_types)
|
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self.it = 0
|
self.it = 0
|
||||||
|
|
||||||
|
|
||||||
def _device_quantile_transform(data, feature_names, feature_types):
|
def _device_quantile_transform(data, feature_names, feature_types, enable_categorical):
|
||||||
if _is_cudf_df(data):
|
if _is_cudf_df(data) or _is_cudf_ser(data):
|
||||||
return _transform_cudf_df(data, feature_names, feature_types)
|
return _transform_cudf_df(
|
||||||
if _is_cudf_ser(data):
|
data, feature_names, feature_types, enable_categorical
|
||||||
return _transform_cudf_df(data, feature_names, feature_types)
|
)
|
||||||
if _is_cupy_array(data):
|
if _is_cupy_array(data):
|
||||||
data = _transform_cupy_array(data)
|
data = _transform_cupy_array(data)
|
||||||
return data, feature_names, feature_types
|
return data, feature_names, feature_types
|
||||||
if _is_dlpack(data):
|
if _is_dlpack(data):
|
||||||
return _transform_dlpack(data), feature_names, feature_types
|
return _transform_dlpack(data), feature_names, feature_types
|
||||||
raise TypeError('Value type is not supported for data iterator:' +
|
raise TypeError("Value type is not supported for data iterator:" + str(type(data)))
|
||||||
str(type(data)))
|
|
||||||
|
|
||||||
|
|
||||||
def dispatch_device_quantile_dmatrix_set_data(proxy: _ProxyDMatrix, data: Any) -> None:
|
def dispatch_device_quantile_dmatrix_set_data(proxy: _ProxyDMatrix, data: Any) -> None:
|
||||||
|
|||||||
@ -171,6 +171,21 @@ Arrow specification.'''
|
|||||||
def test_cudf_metainfo_device_dmatrix(self):
|
def test_cudf_metainfo_device_dmatrix(self):
|
||||||
_test_cudf_metainfo(xgb.DeviceQuantileDMatrix)
|
_test_cudf_metainfo(xgb.DeviceQuantileDMatrix)
|
||||||
|
|
||||||
|
@pytest.mark.skipif(**tm.no_cudf())
|
||||||
|
def test_categorical(self):
|
||||||
|
import cudf
|
||||||
|
_X, _y = tm.make_categorical(100, 30, 17, False)
|
||||||
|
X = cudf.from_pandas(_X)
|
||||||
|
y = cudf.from_pandas(_y)
|
||||||
|
|
||||||
|
Xy = xgb.DMatrix(X, y, enable_categorical=True)
|
||||||
|
assert len(Xy.feature_types) == X.shape[1]
|
||||||
|
assert all(t == "categorical" for t in Xy.feature_types)
|
||||||
|
|
||||||
|
Xy = xgb.DeviceQuantileDMatrix(X, y, enable_categorical=True)
|
||||||
|
assert len(Xy.feature_types) == X.shape[1]
|
||||||
|
assert all(t == "categorical" for t in Xy.feature_types)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cudf())
|
@pytest.mark.skipif(**tm.no_cudf())
|
||||||
@pytest.mark.skipif(**tm.no_cupy())
|
@pytest.mark.skipif(**tm.no_cupy())
|
||||||
|
|||||||
@ -43,22 +43,8 @@ class TestGPUUpdaters:
|
|||||||
assert tm.non_increasing(result['train'][dataset.metric])
|
assert tm.non_increasing(result['train'][dataset.metric])
|
||||||
|
|
||||||
def run_categorical_basic(self, rows, cols, rounds, cats):
|
def run_categorical_basic(self, rows, cols, rounds, cats):
|
||||||
import pandas as pd
|
onehot, label = tm.make_categorical(rows, cols, cats, True)
|
||||||
rng = np.random.RandomState(1994)
|
cat, _ = tm.make_categorical(rows, cols, cats, False)
|
||||||
|
|
||||||
pd_dict = {}
|
|
||||||
for i in range(cols):
|
|
||||||
c = rng.randint(low=0, high=cats+1, size=rows)
|
|
||||||
pd_dict[str(i)] = pd.Series(c, dtype=np.int64)
|
|
||||||
|
|
||||||
df = pd.DataFrame(pd_dict)
|
|
||||||
label = df.iloc[:, 0]
|
|
||||||
for i in range(0, cols-1):
|
|
||||||
label += df.iloc[:, i]
|
|
||||||
label += 1
|
|
||||||
df = df.astype('category')
|
|
||||||
onehot = pd.get_dummies(df)
|
|
||||||
cat = df
|
|
||||||
|
|
||||||
by_etl_results = {}
|
by_etl_results = {}
|
||||||
by_builtin_results = {}
|
by_builtin_results = {}
|
||||||
|
|||||||
@ -234,6 +234,34 @@ def get_mq2008(dpath):
|
|||||||
x_valid, y_valid, qid_valid)
|
x_valid, y_valid, qid_valid)
|
||||||
|
|
||||||
|
|
||||||
|
@memory.cache
|
||||||
|
def make_categorical(
|
||||||
|
n_samples: int, n_features: int, n_categories: int, onehot_enc: bool
|
||||||
|
):
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
rng = np.random.RandomState(1994)
|
||||||
|
|
||||||
|
pd_dict = {}
|
||||||
|
for i in range(n_features + 1):
|
||||||
|
c = rng.randint(low=0, high=n_categories + 1, size=n_samples)
|
||||||
|
pd_dict[str(i)] = pd.Series(c, dtype=np.int64)
|
||||||
|
|
||||||
|
df = pd.DataFrame(pd_dict)
|
||||||
|
label = df.iloc[:, 0]
|
||||||
|
df = df.iloc[:, 1:]
|
||||||
|
for i in range(0, n_features):
|
||||||
|
label += df.iloc[:, i]
|
||||||
|
label += 1
|
||||||
|
|
||||||
|
df = df.astype("category")
|
||||||
|
if onehot_enc:
|
||||||
|
cat = pd.get_dummies(df)
|
||||||
|
else:
|
||||||
|
cat = df
|
||||||
|
return cat, label
|
||||||
|
|
||||||
|
|
||||||
_unweighted_datasets_strategy = strategies.sampled_from(
|
_unweighted_datasets_strategy = strategies.sampled_from(
|
||||||
[TestDataset('boston', get_boston, 'reg:squarederror', 'rmse'),
|
[TestDataset('boston', get_boston, 'reg:squarederror', 'rmse'),
|
||||||
TestDataset('digits', get_digits, 'multi:softmax', 'mlogloss'),
|
TestDataset('digits', get_digits, 'multi:softmax', 'mlogloss'),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user