[doc] Document Python inputs. (#8643)
This commit is contained in:
parent
4e12f3e1bc
commit
1b58d81315
@ -32,24 +32,9 @@ To verify your installation, run the following in Python:
|
|||||||
|
|
||||||
Data Interface
|
Data Interface
|
||||||
--------------
|
--------------
|
||||||
The XGBoost python module is able to load data from many different types of data format,
|
The XGBoost Python module is able to load data from many different types of data format including both CPU and GPU data structures. For a complete list of supported data types, please reference the :ref:`py-data`. For a detailed description of text input formats, please visit :doc:`/tutorials/input_format`.
|
||||||
including:
|
|
||||||
|
|
||||||
- NumPy 2D array
|
The input data is stored in a :py:class:`DMatrix <xgboost.DMatrix>` object. For the sklearn estimator interface, a :py:class:`DMatrix` or a :py:class:`QuantileDMatrix` is created depending on the chosen algorithm and the input, see the sklearn API reference for details. We will illustrate some of the basic input types with the ``DMatrix`` here.
|
||||||
- SciPy 2D sparse array
|
|
||||||
- Pandas data frame
|
|
||||||
- cuDF DataFrame
|
|
||||||
- cupy 2D array
|
|
||||||
- dlpack
|
|
||||||
- datatable
|
|
||||||
- XGBoost binary buffer file.
|
|
||||||
- LIBSVM text format file
|
|
||||||
- Comma-separated values (CSV) file
|
|
||||||
- Arrow table.
|
|
||||||
|
|
||||||
(See :doc:`/tutorials/input_format` for detailed description of text input format.)
|
|
||||||
|
|
||||||
The data is stored in a :py:class:`DMatrix <xgboost.DMatrix>` object.
|
|
||||||
|
|
||||||
* To load a NumPy array into :py:class:`DMatrix <xgboost.DMatrix>`:
|
* To load a NumPy array into :py:class:`DMatrix <xgboost.DMatrix>`:
|
||||||
|
|
||||||
@ -120,6 +105,81 @@ to number of groups.
|
|||||||
recommended to use pandas ``read_csv`` or other similar utilites than XGBoost's builtin
|
recommended to use pandas ``read_csv`` or other similar utilites than XGBoost's builtin
|
||||||
parser.
|
parser.
|
||||||
|
|
||||||
|
.. _py-data:
|
||||||
|
|
||||||
|
Supported data structures for various XGBoost functions
|
||||||
|
=======================================================
|
||||||
|
|
||||||
|
*******
|
||||||
|
Markers
|
||||||
|
*******
|
||||||
|
|
||||||
|
- T: Supported.
|
||||||
|
- F: Not supported.
|
||||||
|
- NE: Invalid type for the use case. For instance, `pd.Series` can not be multi-target label.
|
||||||
|
- NPA: Support with the help of numpy array.
|
||||||
|
- CPA: Support with the help of cupy array.
|
||||||
|
- SciCSR: Support with the help of scripy sparse CSR. The conversion to scipy CSR may or may not be possible. Raise a type error if conversion fails.
|
||||||
|
- FF: We can look forward to having its support in recent future if requested.
|
||||||
|
- empty: To be filled in.
|
||||||
|
|
||||||
|
************
|
||||||
|
Table Header
|
||||||
|
************
|
||||||
|
- `X` means predictor matrix.
|
||||||
|
- Meta info: label, weight, etc.
|
||||||
|
- Multi Label: 2-dim label for multi-target.
|
||||||
|
- Others: Anything else that we don't list here explicitly including formats like `lil`, `dia`, `bsr`. XGBoost will try to convert it into scipy csr.
|
||||||
|
|
||||||
|
**************
|
||||||
|
Support Matrix
|
||||||
|
**************
|
||||||
|
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| Name | DMatrix X | QuantileDMatrix X | Sklearn X | Meta Info | Inplace prediction | Multi Label |
|
||||||
|
+=========================+===========+===================+===========+===========+====================+=============+
|
||||||
|
| numpy.ndarray | T | T | T | T | T | T |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| scipy.sparse.csr | T | T | T | NE | T | F |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| scipy.sparse.csc | T | F | T | NE | F | F |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| scipy.sparse.coo | SciCSR | F | SciCSR | NE | F | F |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| uri | T | F | F | F | NE | F |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| list | NPA | NPA | NPA | NPA | NPA | T |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| tuple | NPA | NPA | NPA | NPA | NPA | T |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| pandas.DataFrame | NPA | NPA | NPA | NPA | NPA | NPA |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| pandas.Series | NPA | NPA | NPA | NPA | NPA | NE |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| cudf.DataFrame | T | T | T | T | T | T |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| cudf.Series | T | T | T | T | FF | NE |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| cupy.ndarray | T | T | T | T | T | T |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| dlpack | CPA | CPA | | CPA | FF | FF |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| datatable.Frame | T | FF | | NPA | FF | |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| datatable.Table | T | FF | | NPA | FF | |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| modin.DataFrame | NPA | FF | NPA | NPA | FF | |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| modin.Series | NPA | FF | NPA | NPA | FF | |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| pyarrow.Table | T | F | | NPA | FF | |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| pyarrow.dataset.Dataset | T | F | | | F | |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| _\_array\_\_ | NPA | F | NPA | NPA | H | |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
| Others | SciCSR | F | | F | F | |
|
||||||
|
+-------------------------+-----------+-------------------+-----------+-----------+--------------------+-------------+
|
||||||
|
|
||||||
Setting Parameters
|
Setting Parameters
|
||||||
------------------
|
------------------
|
||||||
|
|||||||
@ -619,11 +619,11 @@ class DataSplitMode(IntEnum):
|
|||||||
class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods
|
class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-methods
|
||||||
"""Data Matrix used in XGBoost.
|
"""Data Matrix used in XGBoost.
|
||||||
|
|
||||||
DMatrix is an internal data structure that is used by XGBoost,
|
DMatrix is an internal data structure that is used by XGBoost, which is optimized
|
||||||
which is optimized for both memory efficiency and training speed.
|
for both memory efficiency and training speed. You can construct DMatrix from
|
||||||
You can construct DMatrix from multiple different sources of data.
|
multiple different sources of data.
|
||||||
"""
|
|
||||||
|
|
||||||
|
"""
|
||||||
@_deprecate_positional_args
|
@_deprecate_positional_args
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -647,15 +647,9 @@ class DMatrix: # pylint: disable=too-many-instance-attributes,too-many-public-m
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Parameters
|
"""Parameters
|
||||||
----------
|
----------
|
||||||
data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/
|
data :
|
||||||
dt.Frame/cudf.DataFrame/cupy.array/dlpack/arrow.Table
|
Data source of DMatrix. See :ref:`py-data` for a list of supported input
|
||||||
|
types.
|
||||||
Data source of DMatrix.
|
|
||||||
|
|
||||||
When data is string or os.PathLike type, it represents the path libsvm
|
|
||||||
format txt file, csv file (by specifying uri parameter
|
|
||||||
'path_to_csv?format=csv'), or binary file that xgboost can read from.
|
|
||||||
|
|
||||||
label : array_like
|
label : array_like
|
||||||
Label of the training data.
|
Label of the training data.
|
||||||
weight : array_like
|
weight : array_like
|
||||||
|
|||||||
@ -939,7 +939,14 @@ class XGBModel(XGBModelBase):
|
|||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
X :
|
X :
|
||||||
Feature matrix
|
Feature matrix. See :ref:`py-data` for a list of supported types.
|
||||||
|
|
||||||
|
When the ``tree_method`` is set to ``hist`` or ``gpu_hist``, internally, the
|
||||||
|
:py:class:`QuantileDMatrix` will be used instead of the :py:class:`DMatrix`
|
||||||
|
for conserving memory. However, this has performance implications when the
|
||||||
|
device of input data is not matched with algorithm. For instance, if the
|
||||||
|
input is a numpy array on CPU but ``gpu_hist`` is used for training, then
|
||||||
|
the data is first processed on CPU then transferred to GPU.
|
||||||
y :
|
y :
|
||||||
Labels
|
Labels
|
||||||
sample_weight :
|
sample_weight :
|
||||||
@ -982,6 +989,7 @@ class XGBModel(XGBModelBase):
|
|||||||
callbacks :
|
callbacks :
|
||||||
.. deprecated:: 1.6.0
|
.. deprecated:: 1.6.0
|
||||||
Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead.
|
Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
with config_context(verbosity=self.verbosity):
|
with config_context(verbosity=self.verbosity):
|
||||||
evals_result: TrainingCallback.EvalsLog = {}
|
evals_result: TrainingCallback.EvalsLog = {}
|
||||||
@ -1567,7 +1575,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
|||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
X : array_like
|
X : array_like
|
||||||
Feature matrix.
|
Feature matrix. See :ref:`py-data` for a list of supported types.
|
||||||
ntree_limit : int
|
ntree_limit : int
|
||||||
Deprecated, use `iteration_range` instead.
|
Deprecated, use `iteration_range` instead.
|
||||||
validate_features : bool
|
validate_features : bool
|
||||||
@ -1846,7 +1854,14 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
X :
|
X :
|
||||||
Feature matrix
|
Feature matrix. See :ref:`py-data` for a list of supported types.
|
||||||
|
|
||||||
|
When the ``tree_method`` is set to ``hist`` or ``gpu_hist``, internally, the
|
||||||
|
:py:class:`QuantileDMatrix` will be used instead of the :py:class:`DMatrix`
|
||||||
|
for conserving memory. However, this has performance implications when the
|
||||||
|
device of input data is not matched with algorithm. For instance, if the
|
||||||
|
input is a numpy array on CPU but ``gpu_hist`` is used for training, then
|
||||||
|
the data is first processed on CPU then transferred to GPU.
|
||||||
y :
|
y :
|
||||||
Labels
|
Labels
|
||||||
group :
|
group :
|
||||||
@ -1917,6 +1932,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
|||||||
callbacks :
|
callbacks :
|
||||||
.. deprecated:: 1.6.0
|
.. deprecated:: 1.6.0
|
||||||
Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead.
|
Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# check if group information is provided
|
# check if group information is provided
|
||||||
with config_context(verbosity=self.verbosity):
|
with config_context(verbosity=self.verbosity):
|
||||||
|
|||||||
@ -1084,6 +1084,12 @@ def test_pandas_input():
|
|||||||
)
|
)
|
||||||
np.testing.assert_allclose(np.array(clf_isotonic.classes_), np.array([0, 1]))
|
np.testing.assert_allclose(np.array(clf_isotonic.classes_), np.array([0, 1]))
|
||||||
|
|
||||||
|
train_ser = train["k1"]
|
||||||
|
assert isinstance(train_ser, pd.Series)
|
||||||
|
model = xgb.XGBClassifier(n_estimators=8)
|
||||||
|
model.fit(train_ser, target, eval_set=[(train_ser, target)])
|
||||||
|
assert tm.non_increasing(model.evals_result()["validation_0"]["logloss"])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("tree_method", ["approx", "hist"])
|
@pytest.mark.parametrize("tree_method", ["approx", "hist"])
|
||||||
def test_feature_weights(tree_method):
|
def test_feature_weights(tree_method):
|
||||||
@ -1239,6 +1245,10 @@ def test_multilabel_classification() -> None:
|
|||||||
np.testing.assert_allclose(clf.predict(X), predt)
|
np.testing.assert_allclose(clf.predict(X), predt)
|
||||||
assert predt.dtype == np.int64
|
assert predt.dtype == np.int64
|
||||||
|
|
||||||
|
y = y.tolist()
|
||||||
|
clf.fit(X, y)
|
||||||
|
np.testing.assert_allclose(clf.predict(X), predt)
|
||||||
|
|
||||||
|
|
||||||
def test_data_initialization():
|
def test_data_initialization():
|
||||||
from sklearn.datasets import load_digits
|
from sklearn.datasets import load_digits
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user