Initial support for multi-target tree. (#8616)

* Implement multi-target for hist.
  - Add new hist tree builder.
  - Move data fetchers for tests.
  - Dispatch function calls in gbm based on the tree type.

This commit is contained in:
parent ea04d4c46c
commit 151882dd26
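Taken together, the hunks below add a ``multi_strategy`` training parameter with two values: ``one_output_per_tree`` (the old behavior) and ``multi_output_tree`` (new vector-leaf trees built by the hist method). A minimal sketch of the resulting user-facing API; the data is synthetic and the expected shapes are assumptions based on the diff, not output captured from the commit:

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    # Synthetic multi-target regression data: 3 targets per row.
    X = np.random.rand(256, 8)
    y = np.random.rand(256, 3)

    # "multi_output_tree" builds one tree with vector leaves covering all
    # targets; the default "one_output_per_tree" keeps one tree per target.
    # Only tree_method="hist" supports vector leaves at this point.
    reg = xgb.XGBRegressor(tree_method="hist", multi_strategy="multi_output_tree")
    reg.fit(X, y)
    assert reg.predict(X).shape == (256, 3)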
@@ -7,6 +7,12 @@ The demo is adopted from scikit-learn:
 https://scikit-learn.org/stable/auto_examples/ensemble/plot_random_forest_regression_multioutput.html#sphx-glr-auto-examples-ensemble-plot-random-forest-regression-multioutput-py

 See :doc:`/tutorials/multioutput` for more information.

+.. note::
+
+    The feature is experimental. For the `multi_output_tree` strategy, many features are
+    missing.
+
 """

 import argparse
@@ -40,11 +46,18 @@ def gen_circle() -> Tuple[np.ndarray, np.ndarray]:
     return X, y


-def rmse_model(plot_result: bool):
+def rmse_model(plot_result: bool, strategy: str):
     """Draw a circle with 2-dim coordinate as target variables."""
     X, y = gen_circle()
     # Train a regressor on it
-    reg = xgb.XGBRegressor(tree_method="hist", n_estimators=64)
+    reg = xgb.XGBRegressor(
+        tree_method="hist",
+        n_estimators=128,
+        n_jobs=16,
+        max_depth=8,
+        multi_strategy=strategy,
+        subsample=0.6,
+    )
     reg.fit(X, y, eval_set=[(X, y)])

     y_predt = reg.predict(X)
@@ -52,7 +65,7 @@ def rmse_model(plot_result: bool):
         plot_predt(y, y_predt, "multi")


-def custom_rmse_model(plot_result: bool) -> None:
+def custom_rmse_model(plot_result: bool, strategy: str) -> None:
     """Train using Python implementation of Squared Error."""

     # As the experimental support status, custom objective doesn't support matrix as
@@ -88,9 +101,10 @@ def custom_rmse_model(plot_result: bool) -> None:
         {
             "tree_method": "hist",
             "num_target": y.shape[1],
+            "multi_strategy": strategy,
         },
         dtrain=Xy,
-        num_boost_round=100,
+        num_boost_round=128,
         obj=squared_log,
         evals=[(Xy, "Train")],
         evals_result=results,
@ -107,6 +121,16 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--plot", choices=[0, 1], type=int, default=1)
|
parser.add_argument("--plot", choices=[0, 1], type=int, default=1)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
# Train with builtin RMSE objective
|
# Train with builtin RMSE objective
|
||||||
rmse_model(args.plot == 1)
|
# - One model per output.
|
||||||
|
rmse_model(args.plot == 1, "one_output_per_tree")
|
||||||
|
|
||||||
|
# - One model for all outputs, this is still working in progress, many features are
|
||||||
|
# missing.
|
||||||
|
rmse_model(args.plot == 1, "multi_output_tree")
|
||||||
|
|
||||||
# Train with custom objective.
|
# Train with custom objective.
|
||||||
custom_rmse_model(args.plot == 1)
|
# - One model per output.
|
||||||
|
custom_rmse_model(args.plot == 1, "one_output_per_tree")
|
||||||
|
# - One model for all outputs, this is still working in progress, many features are
|
||||||
|
# missing.
|
||||||
|
custom_rmse_model(args.plot == 1, "multi_output_tree")
|
||||||
|
|||||||
@@ -226,6 +226,18 @@ Parameters for Tree Booster
   list is a group of indices of features that are allowed to interact with each other.
   See :doc:`/tutorials/feature_interaction_constraint` for more information.

+* ``multi_strategy``, [default = ``one_output_per_tree``]
+
+  .. versionadded:: 2.0.0
+
+  .. note:: This parameter is a work in progress.
+
+  - The strategy used for training multi-target models, including multi-target regression
+    and multi-class classification. See :doc:`/tutorials/multioutput` for more information.
+
+  - ``one_output_per_tree``: One model for each target.
+  - ``multi_output_tree``: Use multi-target trees.
+
 .. _cat-param:

 Parameters for Categorical Feature
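For the native interface documented above, the new parameter travels through the regular parameter dictionary. A small sketch under the same assumptions (synthetic data, hist tree method):

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(128, 4)
    y = np.random.rand(128, 3)  # three targets
    booster = xgb.train(
        {"tree_method": "hist", "multi_strategy": "multi_output_tree"},
        xgb.DMatrix(X, label=y),
        num_boost_round=8,
    )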
@@ -11,7 +11,11 @@ can be simultaneously classified as both sci-fi and comedy. For detailed explan
 terminologies related to different multi-output models please refer to the
 :doc:`scikit-learn user guide <sklearn:modules/multiclass>`.

-Internally, XGBoost builds one model for each target similar to sklearn meta estimators,
+**********************************
+Training with One-Model-Per-Target
+**********************************
+
+By default, XGBoost builds one model for each target similar to sklearn meta estimators,
 with the added benefit of reusing data and other integrated features like SHAP. For a
 worked example of regression, see
 :ref:`sphx_glr_python_examples_multioutput_regression.py`. For multi-label classification,
@@ -36,3 +40,26 @@ dense matrix for labels.

 The feature is still under development with limited support from objectives and metrics.

+*************************
+Training with Vector Leaf
+*************************
+
+.. versionadded:: 2.0
+
+.. note::
+
+   This is still a work in progress, and many features are missing.
+
+XGBoost can optionally build multi-output trees, with the size of a leaf equal to the
+number of targets, when the tree method `hist` is used. The behavior can be controlled by
+the ``multi_strategy`` training parameter, which can take the value `one_output_per_tree`
+(the default) for building one model per target, or `multi_output_tree` for building
+multi-output trees.
+
+.. code-block:: python
+
+  clf = xgb.XGBClassifier(tree_method="hist", multi_strategy="multi_output_tree")
+
+See :ref:`sphx_glr_python_examples_multioutput_regression.py` for a worked example with
+regression.
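The tutorial's one-liner above extends naturally to multi-label classification. A hedged sketch built on scikit-learn's synthetic generator; given the experimental status noted in the diff, some objectives and metrics may not work yet:

.. code-block:: python

    import xgboost as xgb
    from sklearn.datasets import make_multilabel_classification

    X, y = make_multilabel_classification(n_samples=100, n_classes=5, random_state=0)
    clf = xgb.XGBClassifier(tree_method="hist", multi_strategy="multi_output_tree")
    clf.fit(X, y)
    assert clf.predict(X).shape == (100, 5)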
@@ -286,8 +286,8 @@ struct LearnerModelParamLegacy;
  * \brief Strategy for building multi-target models.
  */
 enum class MultiStrategy : std::int32_t {
-  kComposite = 0,
-  kMonolithic = 1,
+  kOneOutputPerTree = 0,
+  kMultiOutputTree = 1,
 };

 /**
@@ -317,7 +317,7 @@ struct LearnerModelParam {
   /**
    * \brief Strategy for building multi-target models.
    */
-  MultiStrategy multi_strategy{MultiStrategy::kComposite};
+  MultiStrategy multi_strategy{MultiStrategy::kOneOutputPerTree};

   LearnerModelParam() = default;
   // As the old `LearnerModelParamLegacy` is still used by binary IO, we keep
@@ -338,7 +338,7 @@ struct LearnerModelParam {

   void Copy(LearnerModelParam const& that);
   [[nodiscard]] bool IsVectorLeaf() const noexcept {
-    return multi_strategy == MultiStrategy::kMonolithic;
+    return multi_strategy == MultiStrategy::kMultiOutputTree;
   }
   [[nodiscard]] bst_target_t OutputLength() const noexcept { return this->num_output_group; }
   [[nodiscard]] bst_target_t LeafLength() const noexcept {
@@ -530,17 +530,17 @@ class TensorView {
   /**
    * \brief Number of items in the tensor.
    */
-  LINALG_HD [[nodiscard]] std::size_t Size() const { return size_; }
+  [[nodiscard]] LINALG_HD std::size_t Size() const { return size_; }
   /**
    * \brief Whether this is a contiguous array, both C and F contiguous returns true.
    */
-  LINALG_HD [[nodiscard]] bool Contiguous() const {
+  [[nodiscard]] LINALG_HD bool Contiguous() const {
     return data_.size() == this->Size() || this->CContiguous() || this->FContiguous();
   }
   /**
    * \brief Whether it's a c-contiguous array.
    */
-  LINALG_HD [[nodiscard]] bool CContiguous() const {
+  [[nodiscard]] LINALG_HD bool CContiguous() const {
     StrideT stride;
     static_assert(std::is_same<decltype(stride), decltype(stride_)>::value);
     // It's contiguous if the stride can be calculated from shape.
@@ -550,7 +550,7 @@ class TensorView {
   /**
    * \brief Whether it's a f-contiguous array.
    */
-  LINALG_HD [[nodiscard]] bool FContiguous() const {
+  [[nodiscard]] LINALG_HD bool FContiguous() const {
     StrideT stride;
     static_assert(std::is_same<decltype(stride), decltype(stride_)>::value);
     // It's contiguous if the stride can be calculated from shape.
@ -312,6 +312,19 @@ __model_doc = f"""
|
|||||||
needs to be set to have categorical feature support. See :doc:`Categorical Data
|
needs to be set to have categorical feature support. See :doc:`Categorical Data
|
||||||
</tutorials/categorical>` and :ref:`cat-param` for details.
|
</tutorials/categorical>` and :ref:`cat-param` for details.
|
||||||
|
|
||||||
|
multi_strategy : Optional[str]
|
||||||
|
|
||||||
|
.. versionadded:: 2.0.0
|
||||||
|
|
||||||
|
.. note:: This parameter is working-in-progress.
|
||||||
|
|
||||||
|
The strategy used for training multi-target models, including multi-target
|
||||||
|
regression and multi-class classification. See :doc:`/tutorials/multioutput` for
|
||||||
|
more information.
|
||||||
|
|
||||||
|
- ``one_output_per_tree``: One model for each target.
|
||||||
|
- ``multi_output_tree``: Use multi-target trees.
|
||||||
|
|
||||||
eval_metric : Optional[Union[str, List[str], Callable]]
|
eval_metric : Optional[Union[str, List[str], Callable]]
|
||||||
|
|
||||||
.. versionadded:: 1.6.0
|
.. versionadded:: 1.6.0
|
||||||
@ -624,6 +637,7 @@ class XGBModel(XGBModelBase):
|
|||||||
feature_types: Optional[FeatureTypes] = None,
|
feature_types: Optional[FeatureTypes] = None,
|
||||||
max_cat_to_onehot: Optional[int] = None,
|
max_cat_to_onehot: Optional[int] = None,
|
||||||
max_cat_threshold: Optional[int] = None,
|
max_cat_threshold: Optional[int] = None,
|
||||||
|
multi_strategy: Optional[str] = None,
|
||||||
eval_metric: Optional[Union[str, List[str], Callable]] = None,
|
eval_metric: Optional[Union[str, List[str], Callable]] = None,
|
||||||
early_stopping_rounds: Optional[int] = None,
|
early_stopping_rounds: Optional[int] = None,
|
||||||
callbacks: Optional[List[TrainingCallback]] = None,
|
callbacks: Optional[List[TrainingCallback]] = None,
|
||||||
@ -670,6 +684,7 @@ class XGBModel(XGBModelBase):
|
|||||||
self.feature_types = feature_types
|
self.feature_types = feature_types
|
||||||
self.max_cat_to_onehot = max_cat_to_onehot
|
self.max_cat_to_onehot = max_cat_to_onehot
|
||||||
self.max_cat_threshold = max_cat_threshold
|
self.max_cat_threshold = max_cat_threshold
|
||||||
|
self.multi_strategy = multi_strategy
|
||||||
self.eval_metric = eval_metric
|
self.eval_metric = eval_metric
|
||||||
self.early_stopping_rounds = early_stopping_rounds
|
self.early_stopping_rounds = early_stopping_rounds
|
||||||
self.callbacks = callbacks
|
self.callbacks = callbacks
|
||||||
|
|||||||
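Because ``multi_strategy`` is stored as a plain estimator attribute in ``XGBModel.__init__`` above, it round-trips through scikit-learn's parameter machinery like any other booster parameter; a short sketch:

.. code-block:: python

    import xgboost as xgb

    reg = xgb.XGBRegressor(multi_strategy="multi_output_tree")
    assert reg.get_params()["multi_strategy"] == "multi_output_tree"
    # set_params works the same way, e.g. inside a sklearn grid search.
    reg.set_params(multi_strategy="one_output_per_tree")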
@@ -10,11 +10,9 @@ import os
 import platform
 import socket
 import sys
-import zipfile
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import contextmanager
 from io import StringIO
-from pathlib import Path
 from platform import system
 from typing import (
     Any,
@@ -29,7 +27,6 @@ from typing import (
     TypedDict,
     Union,
 )
-from urllib import request

 import numpy as np
 import pytest
@@ -38,6 +35,13 @@ from scipy import sparse
 import xgboost as xgb
 from xgboost.core import ArrayLike
 from xgboost.sklearn import SklObjective
+from xgboost.testing.data import (
+    get_california_housing,
+    get_cancer,
+    get_digits,
+    get_sparse,
+    memory,
+)

 hypothesis = pytest.importorskip("hypothesis")
@ -45,13 +49,8 @@ hypothesis = pytest.importorskip("hypothesis")
|
|||||||
from hypothesis import strategies
|
from hypothesis import strategies
|
||||||
from hypothesis.extra.numpy import arrays
|
from hypothesis.extra.numpy import arrays
|
||||||
|
|
||||||
joblib = pytest.importorskip("joblib")
|
|
||||||
datasets = pytest.importorskip("sklearn.datasets")
|
datasets = pytest.importorskip("sklearn.datasets")
|
||||||
|
|
||||||
Memory = joblib.Memory
|
|
||||||
|
|
||||||
memory = Memory("./cachedir", verbose=0)
|
|
||||||
|
|
||||||
PytestSkip = TypedDict("PytestSkip", {"condition": bool, "reason": str})
|
PytestSkip = TypedDict("PytestSkip", {"condition": bool, "reason": str})
|
||||||
|
|
||||||
|
|
||||||
@@ -353,137 +352,6 @@ class TestDataset:
         return self.name


-@memory.cache
-def get_california_housing() -> Tuple[np.ndarray, np.ndarray]:
-    data = datasets.fetch_california_housing()
-    return data.data, data.target
-
-
-@memory.cache
-def get_digits() -> Tuple[np.ndarray, np.ndarray]:
-    data = datasets.load_digits()
-    return data.data, data.target
-
-
-@memory.cache
-def get_cancer() -> Tuple[np.ndarray, np.ndarray]:
-    return datasets.load_breast_cancer(return_X_y=True)
-
-
-@memory.cache
-def get_sparse() -> Tuple[np.ndarray, np.ndarray]:
-    rng = np.random.RandomState(199)
-    n = 2000
-    sparsity = 0.75
-    X, y = datasets.make_regression(n, random_state=rng)
-    flag = rng.binomial(1, sparsity, X.shape)
-    for i in range(X.shape[0]):
-        for j in range(X.shape[1]):
-            if flag[i, j]:
-                X[i, j] = np.nan
-    return X, y
-
-
-@memory.cache
-def get_ames_housing() -> Tuple[np.ndarray, np.ndarray]:
-    """
-    Number of samples: 1460
-    Number of features: 20
-    Number of categorical features: 10
-    Number of numerical features: 10
-    """
-    from sklearn.datasets import fetch_openml
-
-    X, y = fetch_openml(data_id=42165, as_frame=True, return_X_y=True)
-
-    categorical_columns_subset: List[str] = [
-        "BldgType",  # 5 cats, no nan
-        "GarageFinish",  # 3 cats, nan
-        "LotConfig",  # 5 cats, no nan
-        "Functional",  # 7 cats, no nan
-        "MasVnrType",  # 4 cats, nan
-        "HouseStyle",  # 8 cats, no nan
-        "FireplaceQu",  # 5 cats, nan
-        "ExterCond",  # 5 cats, no nan
-        "ExterQual",  # 4 cats, no nan
-        "PoolQC",  # 3 cats, nan
-    ]
-
-    numerical_columns_subset: List[str] = [
-        "3SsnPorch",
-        "Fireplaces",
-        "BsmtHalfBath",
-        "HalfBath",
-        "GarageCars",
-        "TotRmsAbvGrd",
-        "BsmtFinSF1",
-        "BsmtFinSF2",
-        "GrLivArea",
-        "ScreenPorch",
-    ]
-
-    X = X[categorical_columns_subset + numerical_columns_subset]
-    X[categorical_columns_subset] = X[categorical_columns_subset].astype("category")
-    return X, y
-
-
-@memory.cache
-def get_mq2008(
-    dpath: str,
-) -> Tuple[
-    sparse.csr_matrix,
-    np.ndarray,
-    np.ndarray,
-    sparse.csr_matrix,
-    np.ndarray,
-    np.ndarray,
-    sparse.csr_matrix,
-    np.ndarray,
-    np.ndarray,
-]:
-    from sklearn.datasets import load_svmlight_files
-
-    src = "https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip"
-    target = os.path.join(os.path.expanduser(dpath), "MQ2008.zip")
-    if not os.path.exists(target):
-        request.urlretrieve(url=src, filename=target)
-
-    with zipfile.ZipFile(target, "r") as f:
-        f.extractall(path=dpath)
-
-    (
-        x_train,
-        y_train,
-        qid_train,
-        x_test,
-        y_test,
-        qid_test,
-        x_valid,
-        y_valid,
-        qid_valid,
-    ) = load_svmlight_files(
-        (
-            Path(dpath) / "MQ2008" / "Fold1" / "train.txt",
-            Path(dpath) / "MQ2008" / "Fold1" / "test.txt",
-            Path(dpath) / "MQ2008" / "Fold1" / "vali.txt",
-        ),
-        query_id=True,
-        zero_based=False,
-    )
-
-    return (
-        x_train,
-        y_train,
-        qid_train,
-        x_test,
-        y_test,
-        qid_test,
-        x_valid,
-        y_valid,
-        qid_valid,
-    )
-
-
 # pylint: disable=too-many-arguments,too-many-locals
 @memory.cache
 def make_categorical(
@@ -738,20 +606,7 @@ _unweighted_datasets_strategy = strategies.sampled_from(
         TestDataset(
             "calif_housing-l1", get_california_housing, "reg:absoluteerror", "mae"
         ),
-        TestDataset("digits", get_digits, "multi:softmax", "mlogloss"),
         TestDataset("cancer", get_cancer, "binary:logistic", "logloss"),
-        TestDataset(
-            "mtreg",
-            lambda: datasets.make_regression(n_samples=128, n_features=2, n_targets=3),
-            "reg:squarederror",
-            "rmse",
-        ),
-        TestDataset(
-            "mtreg-l1",
-            lambda: datasets.make_regression(n_samples=128, n_features=2, n_targets=3),
-            "reg:absoluteerror",
-            "mae",
-        ),
         TestDataset("sparse", get_sparse, "reg:squarederror", "rmse"),
         TestDataset("sparse-l1", get_sparse, "reg:absoluteerror", "mae"),
         TestDataset(
@@ -764,37 +619,71 @@ _unweighted_datasets_strategy = strategies.sampled_from(
 )


-@strategies.composite
-def _dataset_weight_margin(draw: Callable) -> TestDataset:
-    data: TestDataset = draw(_unweighted_datasets_strategy)
-    if draw(strategies.booleans()):
-        data.w = draw(
-            arrays(np.float64, (len(data.y)), elements=strategies.floats(0.1, 2.0))
-        )
-    if draw(strategies.booleans()):
-        num_class = 1
-        if data.objective == "multi:softmax":
-            num_class = int(np.max(data.y) + 1)
-        elif data.name.startswith("mtreg"):
-            num_class = data.y.shape[1]
-
-        data.margin = draw(
-            arrays(
-                np.float64,
-                (data.y.shape[0] * num_class),
-                elements=strategies.floats(0.5, 1.0),
-            )
-        )
-        assert data.margin is not None
-        if num_class != 1:
-            data.margin = data.margin.reshape(data.y.shape[0], num_class)
-
-    return data
-
-
-# A strategy for drawing from a set of example datasets
-# May add random weights to the dataset
-dataset_strategy = _dataset_weight_margin()
+def make_datasets_with_margin(
+    unweighted_strategy: strategies.SearchStrategy,
+) -> Callable:
+    """Factory function for creating strategies that generate datasets with weight and
+    base margin.
+
+    """
+
+    @strategies.composite
+    def weight_margin(draw: Callable) -> TestDataset:
+        data: TestDataset = draw(unweighted_strategy)
+        if draw(strategies.booleans()):
+            data.w = draw(
+                arrays(np.float64, (len(data.y)), elements=strategies.floats(0.1, 2.0))
+            )
+        if draw(strategies.booleans()):
+            num_class = 1
+            if data.objective == "multi:softmax":
+                num_class = int(np.max(data.y) + 1)
+            elif data.name.startswith("mtreg"):
+                num_class = data.y.shape[1]
+
+            data.margin = draw(
+                arrays(
+                    np.float64,
+                    (data.y.shape[0] * num_class),
+                    elements=strategies.floats(0.5, 1.0),
+                )
+            )
+            assert data.margin is not None
+            if num_class != 1:
+                data.margin = data.margin.reshape(data.y.shape[0], num_class)
+
+        return data
+
+    return weight_margin
+
+
+# A strategy for drawing from a set of example datasets. May add random weights to the
+# dataset.
+dataset_strategy = make_datasets_with_margin(_unweighted_datasets_strategy)()
+
+
+_unweighted_multi_datasets_strategy = strategies.sampled_from(
+    [
+        TestDataset("digits", get_digits, "multi:softmax", "mlogloss"),
+        TestDataset(
+            "mtreg",
+            lambda: datasets.make_regression(n_samples=128, n_features=2, n_targets=3),
+            "reg:squarederror",
+            "rmse",
+        ),
+        TestDataset(
+            "mtreg-l1",
+            lambda: datasets.make_regression(n_samples=128, n_features=2, n_targets=3),
+            "reg:absoluteerror",
+            "mae",
+        ),
+    ]
+)
+
+# A strategy for drawing from a set of multi-target/multi-class datasets.
+multi_dataset_strategy = make_datasets_with_margin(
+    _unweighted_multi_datasets_strategy
+)()


 def non_increasing(L: Sequence[float], tolerance: float = 1e-4) -> bool:
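The refactor above turns the single composite strategy into a factory, so tests can build weighted, margin-augmented strategies from any unweighted dataset pool. A sketch of the intended usage; ``TestDataset.get_dmat`` is assumed to exist as in the surrounding test helpers:

.. code-block:: python

    import xgboost as xgb
    from hypothesis import given, settings
    from xgboost.testing import multi_dataset_strategy

    @given(data=multi_dataset_strategy)
    @settings(deadline=None, max_examples=10)
    def test_multi_target_smoke(data) -> None:
        # data is a TestDataset drawn from the multi-target pool, possibly
        # with random weights and a base margin attached by the factory.
        xgb.train(
            {"tree_method": "hist", "multi_strategy": "multi_output_tree"},
            data.get_dmat(),
            num_boost_round=2,
        )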
@@ -1,13 +1,20 @@
 """Utilities for data generation."""
-from typing import Any, Generator, Tuple, Union
+import os
+import zipfile
+from typing import Any, Generator, List, Tuple, Union
+from urllib import request

 import numpy as np
 import pytest
 from numpy.random import Generator as RNG
+from scipy import sparse

 import xgboost
 from xgboost.data import pandas_pyarrow_mapper

+joblib = pytest.importorskip("joblib")
+memory = joblib.Memory("./cachedir", verbose=0)
+

 def np_dtypes(
     n_samples: int, n_features: int
@@ -195,3 +202,141 @@ def check_inf(rng: RNG) -> None:

     with pytest.raises(ValueError, match="Input data contains `inf`"):
         xgboost.DMatrix(X, y)
+
+
+@memory.cache
+def get_california_housing() -> Tuple[np.ndarray, np.ndarray]:
+    """Fetch the California housing dataset from sklearn."""
+    datasets = pytest.importorskip("sklearn.datasets")
+    data = datasets.fetch_california_housing()
+    return data.data, data.target
+
+
+@memory.cache
+def get_digits() -> Tuple[np.ndarray, np.ndarray]:
+    """Fetch the digits dataset from sklearn."""
+    datasets = pytest.importorskip("sklearn.datasets")
+    data = datasets.load_digits()
+    return data.data, data.target
+
+
+@memory.cache
+def get_cancer() -> Tuple[np.ndarray, np.ndarray]:
+    """Fetch the breast cancer dataset from sklearn."""
+    datasets = pytest.importorskip("sklearn.datasets")
+    return datasets.load_breast_cancer(return_X_y=True)
+
+
+@memory.cache
+def get_sparse() -> Tuple[np.ndarray, np.ndarray]:
+    """Generate a sparse dataset."""
+    datasets = pytest.importorskip("sklearn.datasets")
+    rng = np.random.RandomState(199)
+    n = 2000
+    sparsity = 0.75
+    X, y = datasets.make_regression(n, random_state=rng)
+    flag = rng.binomial(1, sparsity, X.shape)
+    for i in range(X.shape[0]):
+        for j in range(X.shape[1]):
+            if flag[i, j]:
+                X[i, j] = np.nan
+    return X, y
+
+
+@memory.cache
+def get_ames_housing() -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Number of samples: 1460
+    Number of features: 20
+    Number of categorical features: 10
+    Number of numerical features: 10
+    """
+    datasets = pytest.importorskip("sklearn.datasets")
+    X, y = datasets.fetch_openml(data_id=42165, as_frame=True, return_X_y=True)
+
+    categorical_columns_subset: List[str] = [
+        "BldgType",  # 5 cats, no nan
+        "GarageFinish",  # 3 cats, nan
+        "LotConfig",  # 5 cats, no nan
+        "Functional",  # 7 cats, no nan
+        "MasVnrType",  # 4 cats, nan
+        "HouseStyle",  # 8 cats, no nan
+        "FireplaceQu",  # 5 cats, nan
+        "ExterCond",  # 5 cats, no nan
+        "ExterQual",  # 4 cats, no nan
+        "PoolQC",  # 3 cats, nan
+    ]
+
+    numerical_columns_subset: List[str] = [
+        "3SsnPorch",
+        "Fireplaces",
+        "BsmtHalfBath",
+        "HalfBath",
+        "GarageCars",
+        "TotRmsAbvGrd",
+        "BsmtFinSF1",
+        "BsmtFinSF2",
+        "GrLivArea",
+        "ScreenPorch",
+    ]
+
+    X = X[categorical_columns_subset + numerical_columns_subset]
+    X[categorical_columns_subset] = X[categorical_columns_subset].astype("category")
+    return X, y
+
+
+@memory.cache
+def get_mq2008(
+    dpath: str,
+) -> Tuple[
+    sparse.csr_matrix,
+    np.ndarray,
+    np.ndarray,
+    sparse.csr_matrix,
+    np.ndarray,
+    np.ndarray,
+    sparse.csr_matrix,
+    np.ndarray,
+    np.ndarray,
+]:
+    """Fetch the mq2008 dataset."""
+    datasets = pytest.importorskip("sklearn.datasets")
+    src = "https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip"
+    target = os.path.join(dpath, "MQ2008.zip")
+    if not os.path.exists(target):
+        request.urlretrieve(url=src, filename=target)
+
+    with zipfile.ZipFile(target, "r") as f:
+        f.extractall(path=dpath)
+
+    (
+        x_train,
+        y_train,
+        qid_train,
+        x_test,
+        y_test,
+        qid_test,
+        x_valid,
+        y_valid,
+        qid_valid,
+    ) = datasets.load_svmlight_files(
+        (
+            os.path.join(dpath, "MQ2008/Fold1/train.txt"),
+            os.path.join(dpath, "MQ2008/Fold1/test.txt"),
+            os.path.join(dpath, "MQ2008/Fold1/vali.txt"),
+        ),
+        query_id=True,
+        zero_based=False,
+    )
+
+    return (
+        x_train,
+        y_train,
+        qid_train,
+        x_test,
+        y_test,
+        qid_test,
+        x_valid,
+        y_valid,
+        qid_valid,
+    )
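With the move, the fetchers live in ``xgboost.testing.data`` and are cached on disk through ``joblib.Memory("./cachedir")``, so repeated test runs skip the download or generation step. A usage sketch; the ``./data`` directory is an assumption and must exist before calling ``get_mq2008``:

.. code-block:: python

    from xgboost.testing.data import get_california_housing, get_mq2008

    X, y = get_california_housing()  # cached on disk after the first call
    x_train, y_train, qid_train, *rest = get_mq2008("./data")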
@@ -4,8 +4,8 @@ from typing import cast

 import pytest

-hypothesis = pytest.importorskip("hypothesis")
-from hypothesis import strategies  # pylint:disable=wrong-import-position
+strategies = pytest.importorskip("hypothesis.strategies")

 exact_parameter_strategy = strategies.fixed_dictionaries(
     {
@@ -41,6 +41,26 @@ hist_parameter_strategy = strategies.fixed_dictionaries(
     and (cast(int, x["max_depth"]) > 0 or x["grow_policy"] == "lossguide")
 )

+hist_multi_parameter_strategy = strategies.fixed_dictionaries(
+    {
+        "max_depth": strategies.integers(1, 11),
+        "max_leaves": strategies.integers(0, 1024),
+        "max_bin": strategies.integers(2, 512),
+        "multi_strategy": strategies.sampled_from(
+            ["multi_output_tree", "one_output_per_tree"]
+        ),
+        "grow_policy": strategies.sampled_from(["lossguide", "depthwise"]),
+        "min_child_weight": strategies.floats(0.5, 2.0),
+        # We cannot enable subsampling as the training loss can increase
+        # 'subsample': strategies.floats(0.5, 1.0),
+        "colsample_bytree": strategies.floats(0.5, 1.0),
+        "colsample_bylevel": strategies.floats(0.5, 1.0),
+    }
+).filter(
+    lambda x: (cast(int, x["max_depth"]) > 0 or cast(int, x["max_leaves"]) > 0)
+    and (cast(int, x["max_depth"]) > 0 or x["grow_policy"] == "lossguide")
+)
+
 cat_parameter_strategy = strategies.fixed_dictionaries(
     {
         "max_cat_to_onehot": strategies.integers(1, 128),
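The new ``hist_multi_parameter_strategy`` mirrors ``hist_parameter_strategy`` but additionally samples ``multi_strategy``, letting property-based tests sweep both training modes. A sketch of how a test might consume it; the dataset strategy and ``get_dmat`` helper are assumptions based on the testing module above:

.. code-block:: python

    import xgboost as xgb
    from hypothesis import given, settings
    from xgboost.testing import multi_dataset_strategy
    from xgboost.testing.params import hist_multi_parameter_strategy

    @given(params=hist_multi_parameter_strategy, data=multi_dataset_strategy)
    @settings(deadline=None, max_examples=10)
    def test_hist_multi_params(params, data) -> None:
        params["tree_method"] = "hist"
        xgb.train(params, data.get_dmat(), num_boost_round=2)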
@@ -55,6 +55,7 @@ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows
       *out_dim = 2;
       shape.resize(*out_dim);
       shape.front() = rows;
+      // chunksize can be 1 if it's softmax
       shape.back() = std::min(groups, chunksize);
     }
     break;
@ -359,6 +359,7 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b
|
|||||||
HistogramCuts *cuts) {
|
HistogramCuts *cuts) {
|
||||||
size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
|
size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
|
||||||
auto &cut_values = cuts->cut_values_.HostVector();
|
auto &cut_values = cuts->cut_values_.HostVector();
|
||||||
|
// we use the min_value as the first (0th) element, hence starting from 1.
|
||||||
for (size_t i = 1; i < required_cuts; ++i) {
|
for (size_t i = 1; i < required_cuts; ++i) {
|
||||||
bst_float cpt = summary.data[i].value;
|
bst_float cpt = summary.data[i].value;
|
||||||
if (i == 1 || cpt > cut_values.back()) {
|
if (i == 1 || cpt > cut_values.back()) {
|
||||||
@ -419,8 +420,8 @@ void SketchContainerImpl<WQSketch>::MakeCuts(HistogramCuts* cuts) {
|
|||||||
} else {
|
} else {
|
||||||
AddCutPoint<WQSketch>(a, max_num_bins, cuts);
|
AddCutPoint<WQSketch>(a, max_num_bins, cuts);
|
||||||
// push a value that is greater than anything
|
// push a value that is greater than anything
|
||||||
const bst_float cpt = (a.size > 0) ? a.data[a.size - 1].value
|
const bst_float cpt =
|
||||||
: cuts->min_vals_.HostVector()[fid];
|
(a.size > 0) ? a.data[a.size - 1].value : cuts->min_vals_.HostVector()[fid];
|
||||||
// this must be bigger than last value in a scale
|
// this must be bigger than last value in a scale
|
||||||
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
|
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
|
||||||
cuts->cut_values_.HostVector().push_back(last);
|
cuts->cut_values_.HostVector().push_back(last);
|
||||||
|
|||||||
@@ -352,19 +352,6 @@ struct WQSummary {
       prev_rmax = data[i].rmax;
     }
   }
-  // check consistency of the summary
-  inline bool Check(const char *msg) const {
-    const float tol = 10.0f;
-    for (size_t i = 0; i < this->size; ++i) {
-      if (data[i].rmin + data[i].wmin > data[i].rmax + tol ||
-          data[i].rmin < -1e-6f || data[i].rmax < -1e-6f) {
-        LOG(INFO) << "---------- WQSummary::Check did not pass ----------";
-        this->Print();
-        return false;
-      }
-    }
-    return true;
-  }
 };

 /*! \brief try to do efficient pruning */
@@ -257,6 +257,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
   }
   iter.Reset();
   CHECK_EQ(rbegin, Info().num_row_);
+  CHECK_EQ(this->ghist_->Features(), Info().num_col_);

   /**
    * Generate column matrix
@@ -10,6 +10,7 @@
 #include <dmlc/parameter.h>

 #include <algorithm>
+#include <cinttypes>  // for uint32_t
 #include <limits>
 #include <memory>
 #include <string>
@@ -27,9 +28,11 @@
 #include "xgboost/host_device_vector.h"
 #include "xgboost/json.h"
 #include "xgboost/logging.h"
+#include "xgboost/model.h"
 #include "xgboost/objective.h"
 #include "xgboost/predictor.h"
-#include "xgboost/string_view.h"
+#include "xgboost/string_view.h"  // for StringView
+#include "xgboost/tree_model.h"   // for RegTree
 #include "xgboost/tree_updater.h"

 namespace xgboost::gbm {
@@ -131,6 +134,12 @@ void GBTree::PerformTreeMethodHeuristic(DMatrix* fmat) {
     // set, since only experts are expected to do so.
     return;
   }
+  if (model_.learner_model_param->IsVectorLeaf()) {
+    CHECK(tparam_.tree_method == TreeMethod::kHist)
+        << "Only the hist tree method is supported for building multi-target trees with vector "
+           "leaf.";
+  }
+
   // tparam_ is set before calling this function.
   if (tparam_.tree_method != TreeMethod::kAuto) {
     return;
@@ -175,12 +184,12 @@ void GBTree::ConfigureUpdaters() {
     case TreeMethod::kExact:
       tparam_.updater_seq = "grow_colmaker,prune";
       break;
-    case TreeMethod::kHist:
-      LOG(INFO) <<
-          "Tree method is selected to be 'hist', which uses a "
-          "single updater grow_quantile_histmaker.";
+    case TreeMethod::kHist: {
+      LOG(INFO) << "Tree method is selected to be 'hist', which uses a single updater "
+                   "grow_quantile_histmaker.";
       tparam_.updater_seq = "grow_quantile_histmaker";
       break;
+    }
     case TreeMethod::kGPUHist: {
       common::AssertGPUSupport();
       tparam_.updater_seq = "grow_gpu_hist";
@@ -209,11 +218,9 @@ void CopyGradient(HostDeviceVector<GradientPair> const* in_gpair, int32_t n_thre
     GPUCopyGradient(in_gpair, n_groups, group_id, out_gpair);
   } else {
     std::vector<GradientPair> &tmp_h = out_gpair->HostVector();
-    auto nsize = static_cast<bst_omp_uint>(out_gpair->Size());
-    const auto &gpair_h = in_gpair->ConstHostVector();
-    common::ParallelFor(nsize, n_threads, [&](bst_omp_uint i) {
-      tmp_h[i] = gpair_h[i * n_groups + group_id];
-    });
+    const auto& gpair_h = in_gpair->ConstHostVector();
+    common::ParallelFor(out_gpair->Size(), n_threads,
+                        [&](auto i) { tmp_h[i] = gpair_h[i * n_groups + group_id]; });
   }
 }
@@ -234,6 +241,7 @@ void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector<float> const
   CHECK_EQ(model_.param.num_parallel_tree, trees.size());
   CHECK_EQ(model_.param.num_parallel_tree, 1)
       << "Boosting random forest is not supported for current objective.";
+  CHECK(!trees.front()->IsMultiTarget()) << "Update tree leaf" << MTNotImplemented();
   CHECK_EQ(trees.size(), model_.param.num_parallel_tree);
   for (std::size_t tree_idx = 0; tree_idx < trees.size(); ++tree_idx) {
     auto const& position = node_position.at(tree_idx);
@ -245,17 +253,18 @@ void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector<float> const
|
|||||||
void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
|
void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
|
||||||
PredictionCacheEntry* predt, ObjFunction const* obj) {
|
PredictionCacheEntry* predt, ObjFunction const* obj) {
|
||||||
std::vector<std::vector<std::unique_ptr<RegTree>>> new_trees;
|
std::vector<std::vector<std::unique_ptr<RegTree>>> new_trees;
|
||||||
const int ngroup = model_.learner_model_param->num_output_group;
|
const int ngroup = model_.learner_model_param->OutputLength();
|
||||||
ConfigureWithKnownData(this->cfg_, p_fmat);
|
ConfigureWithKnownData(this->cfg_, p_fmat);
|
||||||
monitor_.Start("BoostNewTrees");
|
monitor_.Start("BoostNewTrees");
|
||||||
|
|
||||||
// Weird case that tree method is cpu-based but gpu_id is set. Ideally we should let
|
// Weird case that tree method is cpu-based but gpu_id is set. Ideally we should let
|
||||||
// `gpu_id` be the single source of determining what algorithms to run, but that will
|
// `gpu_id` be the single source of determining what algorithms to run, but that will
|
||||||
// break a lots of existing code.
|
// break a lots of existing code.
|
||||||
auto device = tparam_.tree_method != TreeMethod::kGPUHist ? Context::kCpuId : ctx_->gpu_id;
|
auto device = tparam_.tree_method != TreeMethod::kGPUHist ? Context::kCpuId : ctx_->gpu_id;
|
||||||
auto out = linalg::TensorView<float, 2>{
|
auto out = linalg::MakeTensorView(
|
||||||
|
device,
|
||||||
device == Context::kCpuId ? predt->predictions.HostSpan() : predt->predictions.DeviceSpan(),
|
device == Context::kCpuId ? predt->predictions.HostSpan() : predt->predictions.DeviceSpan(),
|
||||||
{static_cast<size_t>(p_fmat->Info().num_row_), static_cast<size_t>(ngroup)},
|
p_fmat->Info().num_row_, model_.learner_model_param->OutputLength());
|
||||||
device};
|
|
||||||
CHECK_NE(ngroup, 0);
|
CHECK_NE(ngroup, 0);
|
||||||
|
|
||||||
if (!p_fmat->SingleColBlock() && obj->Task().UpdateTreeLeaf()) {
|
if (!p_fmat->SingleColBlock() && obj->Task().UpdateTreeLeaf()) {
|
||||||
@ -266,7 +275,13 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
|
|||||||
// position is negated if the row is sampled out.
|
// position is negated if the row is sampled out.
|
||||||
std::vector<HostDeviceVector<bst_node_t>> node_position;
|
std::vector<HostDeviceVector<bst_node_t>> node_position;
|
||||||
|
|
||||||
if (ngroup == 1) {
|
if (model_.learner_model_param->IsVectorLeaf()) {
|
||||||
|
std::vector<std::unique_ptr<RegTree>> ret;
|
||||||
|
BoostNewTrees(in_gpair, p_fmat, 0, &node_position, &ret);
|
||||||
|
UpdateTreeLeaf(p_fmat, predt->predictions, obj, 0, node_position, &ret);
|
||||||
|
// No update prediction cache yet.
|
||||||
|
new_trees.push_back(std::move(ret));
|
||||||
|
} else if (model_.learner_model_param->OutputLength() == 1) {
|
||||||
std::vector<std::unique_ptr<RegTree>> ret;
|
std::vector<std::unique_ptr<RegTree>> ret;
|
||||||
BoostNewTrees(in_gpair, p_fmat, 0, &node_position, &ret);
|
BoostNewTrees(in_gpair, p_fmat, 0, &node_position, &ret);
|
||||||
UpdateTreeLeaf(p_fmat, predt->predictions, obj, 0, node_position, &ret);
|
UpdateTreeLeaf(p_fmat, predt->predictions, obj, 0, node_position, &ret);
|
||||||
@ -383,11 +398,15 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fma
|
|||||||
}
|
}
|
||||||
|
|
||||||
// update the trees
|
// update the trees
|
||||||
CHECK_EQ(gpair->Size(), p_fmat->Info().num_row_)
|
auto n_out = model_.learner_model_param->OutputLength() * p_fmat->Info().num_row_;
|
||||||
<< "Mismatching size between number of rows from input data and size of "
|
StringView msg{
|
||||||
"gradient vector.";
|
"Mismatching size between number of rows from input data and size of gradient vector."};
|
||||||
|
if (!model_.learner_model_param->IsVectorLeaf() && p_fmat->Info().num_row_ != 0) {
|
||||||
|
CHECK_EQ(n_out % gpair->Size(), 0) << msg;
|
||||||
|
} else {
|
||||||
|
CHECK_EQ(gpair->Size(), n_out) << msg;
|
||||||
|
}
|
||||||
|
|
||||||
CHECK(out_position);
|
|
||||||
out_position->resize(new_trees.size());
|
out_position->resize(new_trees.size());
|
||||||
|
|
||||||
// Rescale learning rate according to the size of trees
|
// Rescale learning rate according to the size of trees
|
||||||
@ -402,8 +421,12 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fma
|
|||||||
|
|
||||||
void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees) {
|
void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees) {
|
||||||
monitor_.Start("CommitModel");
|
monitor_.Start("CommitModel");
|
||||||
for (uint32_t gid = 0; gid < model_.learner_model_param->num_output_group; ++gid) {
|
if (this->model_.learner_model_param->IsVectorLeaf()) {
|
||||||
model_.CommitModel(std::move(new_trees[gid]), gid);
|
model_.CommitModel(std::move(new_trees[0]), 0);
|
||||||
|
} else {
|
||||||
|
for (std::uint32_t gid = 0; gid < model_.learner_model_param->OutputLength(); ++gid) {
|
||||||
|
model_.CommitModel(std::move(new_trees[gid]), gid);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
monitor_.Stop("CommitModel");
|
monitor_.Stop("CommitModel");
|
||||||
}
|
}
|
||||||
@ -564,11 +587,10 @@ void GBTree::PredictBatch(DMatrix* p_fmat,
|
|||||||
if (out_preds->version == 0) {
|
if (out_preds->version == 0) {
|
||||||
// out_preds->Size() can be non-zero as it's initialized here before any
|
// out_preds->Size() can be non-zero as it's initialized here before any
|
||||||
// tree is built at the 0^th iterator.
|
// tree is built at the 0^th iterator.
|
||||||
predictor->InitOutPredictions(p_fmat->Info(), &out_preds->predictions,
|
predictor->InitOutPredictions(p_fmat->Info(), &out_preds->predictions, model_);
|
||||||
model_);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t tree_begin, tree_end;
|
std::uint32_t tree_begin, tree_end;
|
||||||
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end);
|
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end);
|
||||||
CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
|
CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
|
||||||
if (tree_end > tree_begin) {
|
if (tree_end > tree_begin) {
|
||||||
@ -577,7 +599,7 @@ void GBTree::PredictBatch(DMatrix* p_fmat,
|
|||||||
if (reset) {
|
if (reset) {
|
||||||
out_preds->version = 0;
|
out_preds->version = 0;
|
||||||
} else {
|
} else {
|
||||||
uint32_t delta = layer_end - out_preds->version;
|
std::uint32_t delta = layer_end - out_preds->version;
|
||||||
out_preds->Update(delta);
|
out_preds->Update(delta);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -770,6 +792,7 @@ class Dart : public GBTree {
|
|||||||
void PredictBatchImpl(DMatrix *p_fmat, PredictionCacheEntry *p_out_preds,
|
void PredictBatchImpl(DMatrix *p_fmat, PredictionCacheEntry *p_out_preds,
|
||||||
bool training, unsigned layer_begin,
|
bool training, unsigned layer_begin,
|
||||||
unsigned layer_end) const {
|
unsigned layer_end) const {
|
||||||
|
CHECK(!this->model_.learner_model_param->IsVectorLeaf()) << "dart" << MTNotImplemented();
|
||||||
auto &predictor = this->GetPredictor(&p_out_preds->predictions, p_fmat);
|
auto &predictor = this->GetPredictor(&p_out_preds->predictions, p_fmat);
|
||||||
CHECK(predictor);
|
CHECK(predictor);
|
||||||
predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
|
predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
|
||||||
@ -830,6 +853,7 @@ class Dart : public GBTree {
|
|||||||
void InplacePredict(std::shared_ptr<DMatrix> p_fmat, float missing,
|
void InplacePredict(std::shared_ptr<DMatrix> p_fmat, float missing,
|
||||||
PredictionCacheEntry* p_out_preds, uint32_t layer_begin,
|
PredictionCacheEntry* p_out_preds, uint32_t layer_begin,
|
||||||
unsigned layer_end) const override {
|
unsigned layer_end) const override {
|
||||||
|
CHECK(!this->model_.learner_model_param->IsVectorLeaf()) << "dart" << MTNotImplemented();
|
||||||
uint32_t tree_begin, tree_end;
|
uint32_t tree_begin, tree_end;
|
||||||
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end);
|
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end);
|
||||||
auto n_groups = model_.learner_model_param->num_output_group;
|
auto n_groups = model_.learner_model_param->num_output_group;
|
||||||
|
|||||||
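The new ``CHECK`` in ``PerformTreeMethodHeuristic`` guards vector-leaf training, and the dart paths now reject multi-target models outright. A sketch of the expected failure mode; the exact error text raised is an assumption:

.. code-block:: python

    import numpy as np
    import pytest
    import xgboost as xgb

    X, y = np.random.rand(32, 4), np.random.rand(32, 2)
    # Vector leaves are restricted to tree_method="hist"; other tree methods
    # are expected to fail the new check and raise.
    with pytest.raises(xgb.core.XGBoostError):
        xgb.train(
            {"tree_method": "exact", "multi_strategy": "multi_output_tree"},
            xgb.DMatrix(X, label=y),
            num_boost_round=1,
        )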
@ -139,14 +139,22 @@ struct DartTrainParam : public XGBoostParameter<DartTrainParam> {
|
|||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
// From here on, layer becomes concrete trees.
|
// From here on, layer becomes concrete trees.
|
||||||
inline std::pair<uint32_t, uint32_t> LayerToTree(gbm::GBTreeModel const &model,
|
inline std::pair<uint32_t, uint32_t> LayerToTree(gbm::GBTreeModel const& model,
|
||||||
size_t layer_begin,
|
std::uint32_t layer_begin,
|
||||||
size_t layer_end) {
|
std::uint32_t layer_end) {
|
||||||
bst_group_t groups = model.learner_model_param->num_output_group;
|
std::uint32_t tree_begin;
|
||||||
uint32_t tree_begin = layer_begin * groups * model.param.num_parallel_tree;
|
std::uint32_t tree_end;
|
||||||
uint32_t tree_end = layer_end * groups * model.param.num_parallel_tree;
|
if (model.learner_model_param->IsVectorLeaf()) {
|
||||||
|
tree_begin = layer_begin * model.param.num_parallel_tree;
|
||||||
|
tree_end = layer_end * model.param.num_parallel_tree;
|
||||||
|
} else {
|
||||||
|
bst_group_t groups = model.learner_model_param->OutputLength();
|
||||||
|
tree_begin = layer_begin * groups * model.param.num_parallel_tree;
|
||||||
|
tree_end = layer_end * groups * model.param.num_parallel_tree;
|
||||||
|
}
|
||||||
|
|
||||||
if (tree_end == 0) {
|
if (tree_end == 0) {
|
||||||
tree_end = static_cast<uint32_t>(model.trees.size());
|
tree_end = model.trees.size();
|
||||||
}
|
}
|
||||||
if (model.trees.size() != 0) {
|
if (model.trees.size() != 0) {
|
||||||
CHECK_LE(tree_begin, tree_end);
|
CHECK_LE(tree_begin, tree_end);
|
||||||
@ -234,22 +242,25 @@ class GBTree : public GradientBooster {
|
|||||||
void LoadModel(Json const& in) override;
|
void LoadModel(Json const& in) override;
|
||||||
|
|
||||||
// Number of trees per layer.
|
// Number of trees per layer.
|
||||||
auto LayerTrees() const {
|
[[nodiscard]] std::uint32_t LayerTrees() const {
|
||||||
auto n_trees = model_.learner_model_param->num_output_group * model_.param.num_parallel_tree;
|
if (model_.learner_model_param->IsVectorLeaf()) {
|
||||||
return n_trees;
|
return model_.param.num_parallel_tree;
|
||||||
|
}
|
||||||
|
return model_.param.num_parallel_tree * model_.learner_model_param->OutputLength();
|
||||||
}
|
}
|
||||||
|
|
||||||
// slice the trees, out must be already allocated
|
// slice the trees, out must be already allocated
|
||||||
void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
|
void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
|
||||||
GradientBooster *out, bool* out_of_bound) const override;
|
GradientBooster *out, bool* out_of_bound) const override;
|
||||||
|
|
||||||
int32_t BoostedRounds() const override {
|
[[nodiscard]] std::int32_t BoostedRounds() const override {
|
||||||
CHECK_NE(model_.param.num_parallel_tree, 0);
|
CHECK_NE(model_.param.num_parallel_tree, 0);
|
||||||
CHECK_NE(model_.learner_model_param->num_output_group, 0);
|
CHECK_NE(model_.learner_model_param->num_output_group, 0);
|
||||||
|
|
||||||
return model_.trees.size() / this->LayerTrees();
|
return model_.trees.size() / this->LayerTrees();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ModelFitted() const override {
|
[[nodiscard]] bool ModelFitted() const override {
|
||||||
return !model_.trees.empty() || !model_.trees_to_update.empty();
|
return !model_.trees.empty() || !model_.trees_to_update.empty();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
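The ``LayerToTree``/``LayerTrees`` arithmetic above encodes the model layout: a boosting layer holds ``num_parallel_tree * n_targets`` trees under the default strategy but only ``num_parallel_tree`` trees with vector leaves. A sketch illustrating that both layouts still report the same number of boosted rounds:

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X, y = np.random.rand(64, 4), np.random.rand(64, 3)
    Xy = xgb.DMatrix(X, label=y)

    per_target = xgb.train({"tree_method": "hist"}, Xy, num_boost_round=4)
    vector_leaf = xgb.train(
        {"tree_method": "hist", "multi_strategy": "multi_output_tree"},
        Xy,
        num_boost_round=4,
    )
    # 12 trees vs. 4 trees internally, yet both count 4 rounds.
    assert per_target.num_boosted_rounds() == vector_leaf.num_boosted_rounds() == 4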
@ -326,7 +326,7 @@ struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
|
|||||||
std::string booster;
|
std::string booster;
|
||||||
std::string objective;
|
std::string objective;
|
||||||
// This is a training parameter and is not saved (nor loaded) in the model.
|
// This is a training parameter and is not saved (nor loaded) in the model.
|
||||||
MultiStrategy multi_strategy{MultiStrategy::kComposite};
|
MultiStrategy multi_strategy{MultiStrategy::kOneOutputPerTree};
|
||||||
|
|
||||||
// declare parameters
|
// declare parameters
|
||||||
DMLC_DECLARE_PARAMETER(LearnerTrainParam) {
|
DMLC_DECLARE_PARAMETER(LearnerTrainParam) {
|
||||||
@ -339,12 +339,12 @@ struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
|
|||||||
.set_default("reg:squarederror")
|
.set_default("reg:squarederror")
|
||||||
.describe("Objective function used for obtaining gradient.");
|
.describe("Objective function used for obtaining gradient.");
|
||||||
DMLC_DECLARE_FIELD(multi_strategy)
|
DMLC_DECLARE_FIELD(multi_strategy)
|
||||||
.add_enum("composite", MultiStrategy::kComposite)
|
.add_enum("one_output_per_tree", MultiStrategy::kOneOutputPerTree)
|
||||||
.add_enum("monolithic", MultiStrategy::kMonolithic)
|
.add_enum("multi_output_tree", MultiStrategy::kMultiOutputTree)
|
||||||
.set_default(MultiStrategy::kComposite)
|
.set_default(MultiStrategy::kOneOutputPerTree)
|
||||||
.describe(
|
.describe(
|
||||||
"Strategy used for training multi-target models. `monolithic` means building one "
|
"Strategy used for training multi-target models. `multi_output_tree` means building "
|
||||||
"single tree for all targets.");
|
"one single tree for all targets.");
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -145,7 +145,6 @@ PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
|
|||||||
auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), predt.Size());
|
auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), predt.Size());
|
||||||
|
|
||||||
auto d_group_ptr = p_cache->DataGroupPtr(ctx);
|
auto d_group_ptr = p_cache->DataGroupPtr(ctx);
|
||||||
auto n_groups = info.group_ptr_.size() - 1;
|
|
||||||
|
|
||||||
auto d_inv_idcg = p_cache->InvIDCG(ctx);
|
auto d_inv_idcg = p_cache->InvIDCG(ctx);
|
||||||
auto d_sorted_idx = p_cache->SortedIdx(ctx, d_predt.Values());
|
auto d_sorted_idx = p_cache->SortedIdx(ctx, d_predt.Values());
|
||||||
@ -171,7 +170,6 @@ PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info,
|
|||||||
HostDeviceVector<float> const &predt, bool minus,
|
HostDeviceVector<float> const &predt, bool minus,
|
||||||
std::shared_ptr<ltr::MAPCache> p_cache) {
|
std::shared_ptr<ltr::MAPCache> p_cache) {
|
||||||
auto d_group_ptr = p_cache->DataGroupPtr(ctx);
|
auto d_group_ptr = p_cache->DataGroupPtr(ctx);
|
||||||
auto n_groups = info.group_ptr_.size() - 1;
|
|
||||||
auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
|
auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
|
||||||
|
|
||||||
predt.SetDevice(ctx->gpu_id);
|
predt.SetDevice(ctx->gpu_id);
|
||||||
|
|||||||
@ -87,30 +87,6 @@ bst_float PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree,
|
|||||||
: GetLeafIndex<false, has_categorical>(tree, p_feats, cats);
|
: GetLeafIndex<false, has_categorical>(tree, p_feats, cats);
|
||||||
return tree[leaf].LeafValue();
|
return tree[leaf].LeafValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
|
|
||||||
const size_t tree_end, const size_t predict_offset,
|
|
||||||
const std::vector<RegTree::FVec> &thread_temp, const size_t offset,
|
|
||||||
const size_t block_size, linalg::TensorView<float, 2> out_predt) {
|
|
||||||
for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
|
|
||||||
const size_t gid = model.tree_info[tree_id];
|
|
||||||
auto const &tree = *model.trees[tree_id];
|
|
||||||
auto const &cats = tree.GetCategoriesMatrix();
|
|
||||||
auto has_categorical = tree.HasCategoricalSplit();
|
|
||||||
|
|
||||||
if (has_categorical) {
|
|
||||||
for (std::size_t i = 0; i < block_size; ++i) {
|
|
||||||
out_predt(predict_offset + i, gid) +=
|
|
||||||
PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (std::size_t i = 0; i < block_size; ++i) {
|
|
||||||
out_predt(predict_offset + i, gid) +=
|
|
||||||
PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} // namespace scalar
|
} // namespace scalar
|
||||||
|
|
||||||
namespace multi {
|
namespace multi {
|
||||||
@@ -128,7 +104,7 @@ bst_node_t GetLeafIndex(MultiTargetTree const &tree, const RegTree::FVec &feat,
 }

 template <bool has_categorical>
-void PredValueByOneTree(const RegTree::FVec &p_feats, MultiTargetTree const &tree,
+void PredValueByOneTree(RegTree::FVec const &p_feats, MultiTargetTree const &tree,
                         RegTree::CategoricalSplitMatrix const &cats,
                         linalg::VectorView<float> out_predt) {
   bst_node_t const leaf = p_feats.HasMissing()
@@ -140,36 +116,52 @@ void PredValueByOneTree(const RegTree::FVec &p_feats, MultiTargetTree const &tre
     out_predt(i) += leaf_value(i);
   }
 }
+}  // namespace multi

-void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
-                       const size_t tree_end, const size_t predict_offset,
-                       const std::vector<RegTree::FVec> &thread_temp, const size_t offset,
-                       const size_t block_size, linalg::TensorView<float, 2> out_predt) {
-  for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
+namespace {
+void PredictByAllTrees(gbm::GBTreeModel const &model, std::uint32_t const tree_begin,
+                       std::uint32_t const tree_end, std::size_t const predict_offset,
+                       std::vector<RegTree::FVec> const &thread_temp, std::size_t const offset,
+                       std::size_t const block_size, linalg::MatrixView<float> out_predt) {
+  for (std::uint32_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
     auto const &tree = *model.trees.at(tree_id);
-    auto cats = tree.GetCategoriesMatrix();
+    auto const &cats = tree.GetCategoriesMatrix();
     bool has_categorical = tree.HasCategoricalSplit();

-    if (has_categorical) {
-      for (std::size_t i = 0; i < block_size; ++i) {
-        auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
-        PredValueByOneTree<true>(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats,
-                                 t_predts);
+    if (tree.IsMultiTarget()) {
+      if (has_categorical) {
+        for (std::size_t i = 0; i < block_size; ++i) {
+          auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
+          multi::PredValueByOneTree<true>(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats,
+                                          t_predts);
+        }
+      } else {
+        for (std::size_t i = 0; i < block_size; ++i) {
+          auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
+          multi::PredValueByOneTree<false>(thread_temp[offset + i], *tree.GetMultiTargetTree(),
+                                           cats, t_predts);
+        }
       }
     } else {
-      for (std::size_t i = 0; i < block_size; ++i) {
-        auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
-        PredValueByOneTree<false>(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats,
-                                  t_predts);
+      auto const gid = model.tree_info[tree_id];
+      if (has_categorical) {
+        for (std::size_t i = 0; i < block_size; ++i) {
+          out_predt(predict_offset + i, gid) +=
+              scalar::PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
+        }
+      } else {
+        for (std::size_t i = 0; i < block_size; ++i) {
+          out_predt(predict_offset + i, gid) +=
+              scalar::PredValueByOneTree<false>(thread_temp[offset + i], tree, cats);
+        }
       }
     }
   }
 }
-}  // namespace multi

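The merged function dispatches per tree rather than once per model: a multi-target tree adds its whole leaf vector to the output row, while a scalar tree accumulates into the single column given by its group index. A rough NumPy sketch of that accumulation, with made-up stand-ins for the tree structures:

    import numpy as np

    # Toy model of the dispatch: each tree is (is_multi_target, group_id, leaf_fn),
    # where leaf_fn maps a row to a scalar or to a vector of length n_targets.
    def predict_by_all_trees(trees, X, n_targets):
        out = np.zeros((X.shape[0], n_targets))
        for is_multi_target, gid, leaf_fn in trees:
            for i, row in enumerate(X):
                if is_multi_target:
                    out[i, :] += leaf_fn(row)   # vector leaf updates the whole row
                else:
                    out[i, gid] += leaf_fn(row)  # scalar leaf updates one column
        return out
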
 template <typename DataView>
 void FVecFill(const size_t block_size, const size_t batch_offset, const int num_feature,
-              DataView* batch, const size_t fvec_offset, std::vector<RegTree::FVec>* p_feats) {
+              DataView *batch, const size_t fvec_offset, std::vector<RegTree::FVec> *p_feats) {
   for (size_t i = 0; i < block_size; ++i) {
     RegTree::FVec &feats = (*p_feats)[fvec_offset + i];
     if (feats.Size() == 0) {
@@ -181,8 +173,8 @@ void FVecFill(const size_t block_size, const size_t batch_offset, const int num_
 }

 template <typename DataView>
-void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batch,
-              const size_t fvec_offset, std::vector<RegTree::FVec>* p_feats) {
+void FVecDrop(const size_t block_size, const size_t batch_offset, DataView *batch,
+              const size_t fvec_offset, std::vector<RegTree::FVec> *p_feats) {
   for (size_t i = 0; i < block_size; ++i) {
     RegTree::FVec &feats = (*p_feats)[fvec_offset + i];
     const SparsePage::Inst inst = (*batch)[batch_offset + i];
@@ -190,9 +182,7 @@ void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batc
   }
 }

-namespace {
 static std::size_t constexpr kUnroll = 8;
-}  // anonymous namespace

 struct SparsePageView {
   bst_row_t base_rowid;
@@ -292,7 +282,7 @@ class AdapterView {

 template <typename DataView, size_t block_of_rows_size>
 void PredictBatchByBlockOfRowsKernel(DataView batch, gbm::GBTreeModel const &model,
-                                     int32_t tree_begin, int32_t tree_end,
+                                     std::uint32_t tree_begin, std::uint32_t tree_end,
                                      std::vector<RegTree::FVec> *p_thread_temp, int32_t n_threads,
                                      linalg::TensorView<float, 2> out_predt) {
   auto &thread_temp = *p_thread_temp;
@@ -310,14 +300,8 @@ void PredictBatchByBlockOfRowsKernel(DataView batch, gbm::GBTreeModel const &mod

     FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, p_thread_temp);
     // process block of rows through all trees to keep cache locality
-    if (model.learner_model_param->IsVectorLeaf()) {
-      multi::PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid,
-                               thread_temp, fvec_offset, block_size, out_predt);
-    } else {
-      scalar::PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid,
-                                thread_temp, fvec_offset, block_size, out_predt);
-    }
+    PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid, thread_temp,
+                      fvec_offset, block_size, out_predt);
     FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp);
   });
 }
@@ -348,7 +332,6 @@ void FillNodeMeanValues(RegTree const* tree, std::vector<float>* mean_values) {
   FillNodeMeanValues(tree, 0, mean_values);
 }

-namespace {
 // init thread buffers
 static void InitThreadTemp(int nthread, std::vector<RegTree::FVec> *out) {
   int prev_thread_temp_size = out->size();

@@ -411,7 +411,7 @@ class DeviceModel {

     this->tree_beg_ = tree_begin;
     this->tree_end_ = tree_end;
-    this->num_group = model.learner_model_param->num_output_group;
+    this->num_group = model.learner_model_param->OutputLength();
   }
 };

@@ -306,9 +306,9 @@ class HistogramBuilder {

 // Construct a work space for building histogram. Eventually we should move this
 // function into histogram builder once hist tree method supports external memory.
-template <typename Partitioner>
+template <typename Partitioner, typename ExpandEntry = CPUExpandEntry>
 common::BlockedSpace2d ConstructHistSpace(Partitioner const &partitioners,
-                                          std::vector<CPUExpandEntry> const &nodes_to_build) {
+                                          std::vector<ExpandEntry> const &nodes_to_build) {
   std::vector<size_t> partition_size(nodes_to_build.size(), 0);
   for (auto const &partition : partitioners) {
     size_t k = 0;

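Templating the work-space helper on `ExpandEntry` lets one implementation serve `CPUExpandEntry` and the new multi-target entries. Roughly, the space enumerates (node, row-block) pairs so histogram work can be split both across nodes and within a node's rows; a simplified Python model of the idea, with assumed inputs rather than the real `BlockedSpace2d` API:

    # partition_sizes[i]: number of rows routed to the i-th node to build.
    def construct_hist_space(partition_sizes, grain=512):
        space = []
        for node, n_rows in enumerate(partition_sizes):
            for begin in range(0, n_rows, grain):
                space.append((node, begin, min(begin + grain, n_rows)))
        return space
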
@@ -889,6 +889,8 @@ void RegTree::Save(dmlc::Stream* fo) const {
   CHECK_EQ(param_.num_nodes, static_cast<int>(stats_.size()));
   CHECK_EQ(param_.deprecated_num_roots, 1);
   CHECK_NE(param_.num_nodes, 0);
+  CHECK(!IsMultiTarget())
+      << "Please use JSON/UBJSON for saving models with multi-target trees.";
   CHECK(!HasCategoricalSplit())
       << "Please use JSON/UBJSON for saving models with categorical splits.";

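The new check mirrors the categorical-split one right below it: the legacy binary serializer predates vector leaves, so multi-target trees only round-trip through JSON or UBJSON. A short sketch, assuming `booster` is the multi-target model trained in the earlier example:

    booster.save_model("multi_target.json")  # the JSON schema carries vector leaves
    booster.save_model("multi_target.ubj")   # same schema in binary JSON
    # The deprecated binary format would hit the CHECK above and raise:
    # booster.save_raw(raw_format="deprecated")
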
@@ -4,36 +4,39 @@
  * \brief use quantized feature values to construct a tree
  * \author Philip Cho, Tianqi Checn, Egor Smirnov
  */
-#include <algorithm>                         // for max
+#include <algorithm>                         // for max, copy, transform
 #include <cstddef>                           // for size_t
-#include <cstdint>                           // for uint32_t
+#include <cstdint>                           // for uint32_t, int32_t
-#include <memory>                            // for unique_ptr, allocator, make_unique, make_shared
+#include <memory>                            // for unique_ptr, allocator, make_unique, shared_ptr
-#include <ostream>                           // for operator<<, char_traits, basic_ostream
-#include <tuple>                             // for apply
+#include <numeric>                           // for accumulate
+#include <ostream>                           // for basic_ostream, char_traits, operator<<
 #include <utility>                           // for move, swap
 #include <vector>                            // for vector

 #include "../collective/communicator-inl.h"  // for Allreduce, IsDistributed
 #include "../collective/communicator.h"      // for Operation
 #include "../common/hist_util.h"             // for HistogramCuts, HistCollection
+#include "../common/linalg_op.h"             // for begin, cbegin, cend
 #include "../common/random.h"                // for ColumnSampler
 #include "../common/threading_utils.h"       // for ParallelFor
 #include "../common/timer.h"                 // for Monitor
+#include "../common/transform_iterator.h"    // for IndexTransformIter, MakeIndexTransformIter
 #include "../data/gradient_index.h"          // for GHistIndexMatrix
 #include "common_row_partitioner.h"          // for CommonRowPartitioner
+#include "dmlc/omp.h"                        // for omp_get_thread_num
 #include "dmlc/registry.h"                   // for DMLC_REGISTRY_FILE_TAG
 #include "driver.h"                          // for Driver
-#include "hist/evaluate_splits.h"            // for HistEvaluator, UpdatePredictionCacheImpl
+#include "hist/evaluate_splits.h"            // for HistEvaluator, HistMultiEvaluator, UpdatePre...
-#include "hist/expand_entry.h"               // for CPUExpandEntry
+#include "hist/expand_entry.h"               // for MultiExpandEntry, CPUExpandEntry
 #include "hist/histogram.h"                  // for HistogramBuilder, ConstructHistSpace
 #include "hist/sampler.h"                    // for SampleGradient
-#include "param.h"                           // for TrainParam, GradStats
+#include "param.h"                           // for TrainParam, SplitEntryContainer, GradStats
-#include "xgboost/base.h"                    // for GradientPair, GradientPairInternal, bst_node_t
+#include "xgboost/base.h"                    // for GradientPairInternal, GradientPair, bst_targ...
 #include "xgboost/context.h"                 // for Context
 #include "xgboost/data.h"                    // for BatchIterator, BatchSet, DMatrix, MetaInfo
 #include "xgboost/host_device_vector.h"      // for HostDeviceVector
-#include "xgboost/linalg.h"                  // for TensorView, MatrixView, UnravelIndex, All
+#include "xgboost/linalg.h"                  // for All, MatrixView, TensorView, Matrix, Empty
-#include "xgboost/logging.h"                 // for LogCheck_EQ, LogCheck_GE, CHECK_EQ, LOG, LOG...
+#include "xgboost/logging.h"                 // for LogCheck_EQ, CHECK_EQ, CHECK, LogCheck_GE
 #include "xgboost/span.h"                    // for Span, operator!=, SpanIterator
 #include "xgboost/string_view.h"             // for operator<<
 #include "xgboost/task.h"                    // for ObjInfo
@@ -105,6 +108,212 @@ void UpdateTree(common::Monitor *monitor_, linalg::MatrixView<GradientPair const
   monitor_->Stop(__func__);
 }

+/**
+ * \brief Updater for building multi-target trees. The implementation simply iterates over
+ *        each target.
+ */
+class MultiTargetHistBuilder {
+ private:
+  common::Monitor *monitor_{nullptr};
+  TrainParam const *param_{nullptr};
+  std::shared_ptr<common::ColumnSampler> col_sampler_;
+  std::unique_ptr<HistMultiEvaluator> evaluator_;
+  // Histogram builder for each target.
+  std::vector<HistogramBuilder<MultiExpandEntry>> histogram_builder_;
+  Context const *ctx_{nullptr};
+  // Partitioner for each data batch.
+  std::vector<CommonRowPartitioner> partitioner_;
+  // Pointer to last updated tree, used for update prediction cache.
+  RegTree const *p_last_tree_{nullptr};
+
+  ObjInfo const *task_{nullptr};
+
+ public:
+  void UpdatePosition(DMatrix *p_fmat, RegTree const *p_tree,
+                      std::vector<MultiExpandEntry> const &applied) {
+    monitor_->Start(__func__);
+    std::size_t page_id{0};
+    for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(this->param_))) {
+      this->partitioner_.at(page_id).UpdatePosition(this->ctx_, page, applied, p_tree);
+      page_id++;
+    }
+    monitor_->Stop(__func__);
+  }
+
+  void ApplyTreeSplit(MultiExpandEntry const &candidate, RegTree *p_tree) {
+    this->evaluator_->ApplyTreeSplit(candidate, p_tree);
+  }
+
+  void InitData(DMatrix *p_fmat, RegTree const *p_tree) {
+    monitor_->Start(__func__);
+
+    std::size_t page_id = 0;
+    bst_bin_t n_total_bins = 0;
+    partitioner_.clear();
+    for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
+      if (n_total_bins == 0) {
+        n_total_bins = page.cut.TotalBins();
+      } else {
+        CHECK_EQ(n_total_bins, page.cut.TotalBins());
+      }
+      partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit());
+      page_id++;
+    }
+
+    bst_target_t n_targets = p_tree->NumTargets();
+    histogram_builder_.clear();
+    for (std::size_t i = 0; i < n_targets; ++i) {
+      histogram_builder_.emplace_back();
+      histogram_builder_.back().Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
+                                      collective::IsDistributed(), p_fmat->IsColumnSplit());
+    }
+
+    evaluator_ = std::make_unique<HistMultiEvaluator>(ctx_, p_fmat->Info(), param_, col_sampler_);
+    p_last_tree_ = p_tree;
+    monitor_->Stop(__func__);
+  }
+
+  MultiExpandEntry InitRoot(DMatrix *p_fmat, linalg::MatrixView<GradientPair const> gpair,
+                            RegTree *p_tree) {
+    monitor_->Start(__func__);
+    MultiExpandEntry best;
+    best.nid = RegTree::kRoot;
+    best.depth = 0;
+
+    auto n_targets = p_tree->NumTargets();
+    linalg::Matrix<GradientPairPrecise> root_sum_tloc =
+        linalg::Empty<GradientPairPrecise>(ctx_, ctx_->Threads(), n_targets);
+    CHECK_EQ(root_sum_tloc.Shape(1), gpair.Shape(1));
+    auto h_root_sum_tloc = root_sum_tloc.HostView();
+    common::ParallelFor(gpair.Shape(0), ctx_->Threads(), [&](auto i) {
+      for (bst_target_t t{0}; t < n_targets; ++t) {
+        h_root_sum_tloc(omp_get_thread_num(), t) += GradientPairPrecise{gpair(i, t)};
+      }
+    });
+    // Aggregate to the first row.
+    auto root_sum = h_root_sum_tloc.Slice(0, linalg::All());
+    for (std::int32_t tidx{1}; tidx < ctx_->Threads(); ++tidx) {
+      for (bst_target_t t{0}; t < n_targets; ++t) {
+        root_sum(t) += h_root_sum_tloc(tidx, t);
+      }
+    }
+    CHECK(root_sum.CContiguous());
+    collective::Allreduce<collective::Operation::kSum>(
+        reinterpret_cast<double *>(root_sum.Values().data()), root_sum.Size() * 2);
+
+    std::vector<MultiExpandEntry> nodes{best};
+    std::size_t i = 0;
+    auto space = ConstructHistSpace(partitioner_, nodes);
+    for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
+      for (bst_target_t t{0}; t < n_targets; ++t) {
+        auto t_gpair = gpair.Slice(linalg::All(), t);
+        histogram_builder_[t].BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
+                                        nodes, {}, t_gpair.Values());
+      }
+      i++;
+    }
+
+    auto weight = evaluator_->InitRoot(root_sum);
+    auto weight_t = weight.HostView();
+    std::transform(linalg::cbegin(weight_t), linalg::cend(weight_t), linalg::begin(weight_t),
+                   [&](float w) { return w * param_->learning_rate; });
+
+    p_tree->SetLeaf(RegTree::kRoot, weight_t);
+    std::vector<common::HistCollection const *> hists;
+    for (bst_target_t t{0}; t < p_tree->NumTargets(); ++t) {
+      hists.push_back(&histogram_builder_[t].Histogram());
+    }
+    for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
+      evaluator_->EvaluateSplits(*p_tree, hists, gmat.cut, &nodes);
+      break;
+    }
+    monitor_->Stop(__func__);
+
+    return nodes.front();
+  }
+
+  void BuildHistogram(DMatrix *p_fmat, RegTree const *p_tree,
+                      std::vector<MultiExpandEntry> const &valid_candidates,
+                      linalg::MatrixView<GradientPair const> gpair) {
+    monitor_->Start(__func__);
+    std::vector<MultiExpandEntry> nodes_to_build;
+    std::vector<MultiExpandEntry> nodes_to_sub;
+
+    for (auto const &c : valid_candidates) {
+      auto left_nidx = p_tree->LeftChild(c.nid);
+      auto right_nidx = p_tree->RightChild(c.nid);
+
+      auto build_nidx = left_nidx;
+      auto subtract_nidx = right_nidx;
+      auto lit =
+          common::MakeIndexTransformIter([&](auto i) { return c.split.left_sum[i].GetHess(); });
+      auto left_sum = std::accumulate(lit, lit + c.split.left_sum.size(), .0);
+      auto rit =
+          common::MakeIndexTransformIter([&](auto i) { return c.split.right_sum[i].GetHess(); });
+      auto right_sum = std::accumulate(rit, rit + c.split.right_sum.size(), .0);
+      auto fewer_right = right_sum < left_sum;
+      if (fewer_right) {
+        std::swap(build_nidx, subtract_nidx);
+      }
+      nodes_to_build.emplace_back(build_nidx, p_tree->GetDepth(build_nidx));
+      nodes_to_sub.emplace_back(subtract_nidx, p_tree->GetDepth(subtract_nidx));
+    }
+
+    std::size_t i = 0;
+    auto space = ConstructHistSpace(partitioner_, nodes_to_build);
+    for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
+      for (std::size_t t = 0; t < p_tree->NumTargets(); ++t) {
+        auto t_gpair = gpair.Slice(linalg::All(), t);
+        // Make sure the gradient matrix is f-order.
+        CHECK(t_gpair.Contiguous());
+        histogram_builder_[t].BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
+                                        nodes_to_build, nodes_to_sub, t_gpair.Values());
+      }
+      i++;
+    }
+    monitor_->Stop(__func__);
+  }
+
+  void EvaluateSplits(DMatrix *p_fmat, RegTree const *p_tree,
+                      std::vector<MultiExpandEntry> *best_splits) {
+    monitor_->Start(__func__);
+    std::vector<common::HistCollection const *> hists;
+    for (bst_target_t t{0}; t < p_tree->NumTargets(); ++t) {
+      hists.push_back(&histogram_builder_[t].Histogram());
+    }
+    for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
+      evaluator_->EvaluateSplits(*p_tree, hists, gmat.cut, best_splits);
+      break;
+    }
+    monitor_->Stop(__func__);
+  }
+
+  void LeafPartition(RegTree const &tree, linalg::MatrixView<GradientPair const> gpair,
+                     std::vector<bst_node_t> *p_out_position) {
+    monitor_->Start(__func__);
+    if (!task_->UpdateTreeLeaf()) {
+      return;
+    }
+    for (auto const &part : partitioner_) {
+      part.LeafPartition(ctx_, tree, gpair, p_out_position);
+    }
+    monitor_->Stop(__func__);
+  }
+
+ public:
+  explicit MultiTargetHistBuilder(Context const *ctx, MetaInfo const &info, TrainParam const *param,
+                                  std::shared_ptr<common::ColumnSampler> column_sampler,
+                                  ObjInfo const *task, common::Monitor *monitor)
+      : monitor_{monitor},
+        param_{param},
+        col_sampler_{std::move(column_sampler)},
+        evaluator_{std::make_unique<HistMultiEvaluator>(ctx, info, param, col_sampler_)},
+        ctx_{ctx},
+        task_{task} {
+    monitor_->Init(__func__);
+  }
+};
+
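`BuildHistogram` keeps the standard subtraction trick and applies it independently per target: only the child with the smaller hessian sum gets a fresh histogram scan, and the sibling falls out of the parent by subtraction. A NumPy sketch of the idea under simplified assumptions (a single feature, hessians as bin weights):

    import numpy as np

    def child_histograms(parent_hist, bin_idx, hess, rows_left, rows_right):
        # Scan only the cheaper child, judged by its hessian sum.
        if hess[rows_left].sum() <= hess[rows_right].sum():
            build = rows_left
        else:
            build = rows_right
        built = np.zeros_like(parent_hist)
        np.add.at(built, bin_idx[build], hess[build])  # one pass over fewer rows
        derived = parent_hist - built                  # the sibling for free
        return built, derived
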
 class HistBuilder {
  private:
   common::Monitor *monitor_;
@@ -155,8 +364,7 @@ class HistBuilder {
   // initialize temp data structure
   void InitData(DMatrix *fmat, RegTree const *p_tree) {
     monitor_->Start(__func__);
-    size_t page_id{0};
+    std::size_t page_id{0};
     bst_bin_t n_total_bins{0};
     partitioner_.clear();
     for (auto const &page : fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
@@ -195,7 +403,7 @@ class HistBuilder {
                            RegTree *p_tree) {
     CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0));

-    size_t page_id = 0;
+    std::size_t page_id = 0;
     auto space = ConstructHistSpace(partitioner_, {node});
     for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
       std::vector<CPUExpandEntry> nodes_to_build{node};
@@ -214,13 +422,13 @@ class HistBuilder {
      * of gradient histogram is equal to snode[nid]
      */
     auto const &gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_)).begin());
-    std::vector<uint32_t> const &row_ptr = gmat.cut.Ptrs();
+    std::vector<std::uint32_t> const &row_ptr = gmat.cut.Ptrs();
     CHECK_GE(row_ptr.size(), 2);
-    uint32_t const ibegin = row_ptr[0];
-    uint32_t const iend = row_ptr[1];
+    std::uint32_t const ibegin = row_ptr[0];
+    std::uint32_t const iend = row_ptr[1];
     auto hist = this->histogram_builder_->Histogram()[RegTree::kRoot];
     auto begin = hist.data();
-    for (uint32_t i = ibegin; i < iend; ++i) {
+    for (std::uint32_t i = ibegin; i < iend; ++i) {
       GradientPairPrecise const &et = begin[i];
       grad_stat.Add(et.GetGrad(), et.GetHess());
     }
@@ -259,7 +467,7 @@ class HistBuilder {
     std::vector<CPUExpandEntry> nodes_to_build(valid_candidates.size());
     std::vector<CPUExpandEntry> nodes_to_sub(valid_candidates.size());

-    size_t n_idx = 0;
+    std::size_t n_idx = 0;
     for (auto const &c : valid_candidates) {
       auto left_nidx = (*p_tree)[c.nid].LeftChild();
       auto right_nidx = (*p_tree)[c.nid].RightChild();
@@ -275,7 +483,7 @@ class HistBuilder {
       n_idx++;
     }

-    size_t page_id{0};
+    std::size_t page_id{0};
     auto space = ConstructHistSpace(partitioner_, nodes_to_build);
     for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
       histogram_builder_->BuildHist(page_id, space, gidx, p_tree,
@@ -311,11 +519,12 @@ class HistBuilder {

 /*! \brief construct a tree using quantized feature values */
 class QuantileHistMaker : public TreeUpdater {
-  std::unique_ptr<HistBuilder> p_impl_;
+  std::unique_ptr<HistBuilder> p_impl_{nullptr};
+  std::unique_ptr<MultiTargetHistBuilder> p_mtimpl_{nullptr};
   std::shared_ptr<common::ColumnSampler> column_sampler_ =
       std::make_shared<common::ColumnSampler>();
   common::Monitor monitor_;
-  ObjInfo const *task_;
+  ObjInfo const *task_{nullptr};

  public:
   explicit QuantileHistMaker(Context const *ctx, ObjInfo const *task)
@@ -332,7 +541,10 @@ class QuantileHistMaker : public TreeUpdater {
               const std::vector<RegTree *> &trees) override {
     if (trees.front()->IsMultiTarget()) {
       CHECK(param->monotone_constraints.empty()) << "monotone constraint" << MTNotImplemented();
-      LOG(FATAL) << "Not implemented.";
+      if (!p_mtimpl_) {
+        this->p_mtimpl_ = std::make_unique<MultiTargetHistBuilder>(
+            ctx_, p_fmat->Info(), param, column_sampler_, task_, &monitor_);
+      }
     } else {
       if (!p_impl_) {
         p_impl_ =
@@ -355,13 +567,14 @@ class QuantileHistMaker : public TreeUpdater {

     for (auto tree_it = trees.begin(); tree_it != trees.end(); ++tree_it) {
       if (need_copy()) {
-        // Copy gradient into buffer for sampling.
+        // Copy gradient into buffer for sampling. This converts C-order to F-order.
         std::copy(linalg::cbegin(h_gpair), linalg::cend(h_gpair), linalg::begin(h_sample_out));
       }
       SampleGradient(ctx_, *param, h_sample_out);
       auto *h_out_position = &out_position[tree_it - trees.begin()];
       if ((*tree_it)->IsMultiTarget()) {
-        LOG(FATAL) << "Not implemented.";
+        UpdateTree<MultiExpandEntry>(&monitor_, h_sample_out, p_mtimpl_.get(), p_fmat, param,
+                                     h_out_position, *tree_it);
       } else {
         UpdateTree<CPUExpandEntry>(&monitor_, h_sample_out, p_impl_.get(), p_fmat, param,
                                    h_out_position, *tree_it);
@@ -372,6 +585,9 @@ class QuantileHistMaker : public TreeUpdater {
   bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) override {
     if (p_impl_) {
       return p_impl_->UpdatePredictionCache(data, out_preds);
+    } else if (p_mtimpl_) {
+      // Not yet supported.
+      return false;
     } else {
       return false;
     }
@@ -383,6 +599,6 @@ class QuantileHistMaker : public TreeUpdater {
 XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
     .describe("Grow tree using quantized histogram.")
     .set_body([](Context const *ctx, ObjInfo const *task) {
-      return new QuantileHistMaker(ctx, task);
+      return new QuantileHistMaker{ctx, task};
     });
 }  // namespace xgboost::tree

@@ -3,7 +3,7 @@ import os
 import subprocess
 import sys
 from multiprocessing import Pool, cpu_count
-from typing import Dict, Optional, Tuple
+from typing import Dict, Tuple

 from pylint import epylint
 from test_utils import PY_PACKAGE, ROOT, cd, print_time, record_time
@@ -15,8 +15,11 @@ SRCPATH = os.path.normpath(


 @record_time
-def run_black(rel_path: str) -> bool:
-    cmd = ["black", "-q", "--check", rel_path]
+def run_black(rel_path: str, fix: bool) -> bool:
+    if fix:
+        cmd = ["black", "-q", rel_path]
+    else:
+        cmd = ["black", "-q", "--check", rel_path]
     ret = subprocess.run(cmd).returncode
     if ret != 0:
         subprocess.run(["black", "--version"])
@@ -31,8 +34,11 @@ Please run the following command on your machine to address the formatting error


 @record_time
-def run_isort(rel_path: str) -> bool:
-    cmd = ["isort", f"--src={SRCPATH}", "--check", "--profile=black", rel_path]
+def run_isort(rel_path: str, fix: bool) -> bool:
+    if fix:
+        cmd = ["isort", f"--src={SRCPATH}", "--profile=black", rel_path]
+    else:
+        cmd = ["isort", f"--src={SRCPATH}", "--check", "--profile=black", rel_path]
     ret = subprocess.run(cmd).returncode
     if ret != 0:
         subprocess.run(["isort", "--version"])
@@ -132,7 +138,7 @@ def run_pylint() -> bool:
 def main(args: argparse.Namespace) -> None:
     if args.format == 1:
         black_results = [
-            run_black(path)
+            run_black(path, args.fix)
             for path in [
                 # core
                 "python-package/",
@@ -166,7 +172,7 @@ def main(args: argparse.Namespace) -> None:
         sys.exit(-1)

     isort_results = [
-        run_isort(path)
+        run_isort(path, args.fix)
         for path in [
            # core
            "python-package/",
@@ -230,6 +236,11 @@ if __name__ == "__main__":
     parser.add_argument("--format", type=int, choices=[0, 1], default=1)
     parser.add_argument("--type-check", type=int, choices=[0, 1], default=1)
     parser.add_argument("--pylint", type=int, choices=[0, 1], default=1)
+    parser.add_argument(
+        "--fix",
+        action="store_true",
+        help="Fix the formatting issues instead of emitting an error.",
+    )
     args = parser.parse_args()
     try:
         main(args)

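The lint driver gains a `--fix` mode that simply drops the `--check` flags from the same commands. A hedged usage sketch; it assumes the script's directory (it lives next to `test_utils.py`) is on `sys.path` so the helpers are importable:

    from lint_python import run_black, run_isort

    # CI behavior: check only, report violations and return False.
    ok = run_black("python-package/", fix=False) and run_isort("python-package/", fix=False)

    # Local behavior with --fix: rewrite files in place instead of failing.
    run_black("python-package/", fix=True)
    run_isort("python-package/", fix=True)
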
@@ -412,7 +412,7 @@ std::pair<Json, Json> TestModelSlice(std::string booster) {
     j++;
   }

-  // CHECK sliced model doesn't have dependency on old one
+  // CHECK sliced model doesn't have dependency on the old one
   learner.reset();
   CHECK_EQ(sliced->GetNumFeature(), kCols);

@@ -473,7 +473,7 @@ inline LearnerModelParam MakeMP(bst_feature_t n_features, float base_score, uint
                                 int32_t device = Context::kCpuId) {
   size_t shape[1]{1};
   LearnerModelParam mparam(n_features, linalg::Tensor<float, 1>{{base_score}, shape, device},
-                           n_groups, 1, MultiStrategy::kComposite);
+                           n_groups, 1, MultiStrategy::kOneOutputPerTree);
   return mparam;
 }

@@ -428,7 +428,7 @@ void TestVectorLeafPrediction(Context const *ctx) {

   LearnerModelParam mparam{static_cast<bst_feature_t>(kCols),
                            linalg::Vector<float>{{0.5}, {1}, Context::kCpuId}, 1, 3,
-                           MultiStrategy::kMonolithic};
+                           MultiStrategy::kMultiOutputTree};

   std::vector<std::unique_ptr<RegTree>> trees;
   trees.emplace_back(new RegTree{mparam.LeafLength(), mparam.num_feature});

@@ -124,11 +124,11 @@ TEST(MultiStrategy, Configure) {
   auto p_fmat = RandomDataGenerator{12ul, 3ul, 0.0}.GenerateDMatrix();
   p_fmat->Info().labels.Reshape(p_fmat->Info().num_row_, 2);
   std::unique_ptr<Learner> learner{Learner::Create({p_fmat})};
-  learner->SetParams(Args{{"multi_strategy", "monolithic"}, {"num_target", "2"}});
+  learner->SetParams(Args{{"multi_strategy", "multi_output_tree"}, {"num_target", "2"}});
   learner->Configure();
   ASSERT_EQ(learner->Groups(), 2);

-  learner->SetParams(Args{{"multi_strategy", "monolithic"}, {"num_target", "0"}});
+  learner->SetParams(Args{{"multi_strategy", "multi_output_tree"}, {"num_target", "0"}});
   ASSERT_THROW({ learner->Configure(); }, dmlc::Error);
 }
 }  // namespace xgboost

@@ -116,7 +116,7 @@ def test_with_mq2008(objective, metric) -> None:
         x_valid,
         y_valid,
         qid_valid,
-    ) = tm.get_mq2008(os.path.join(os.path.join(tm.demo_dir(__file__), "rank")))
+    ) = tm.data.get_mq2008(os.path.join(os.path.join(tm.demo_dir(__file__), "rank")))

     if metric.find("map") != -1 or objective.find("map") != -1:
         y_train[y_train <= 1] = 0.0

@@ -32,6 +32,19 @@ def train_result(param, dmat: xgb.DMatrix, num_rounds: int) -> dict:
     return result


+class TestGPUUpdatersMulti:
+    @given(
+        hist_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
+    )
+    @settings(deadline=None, max_examples=50, print_blob=True)
+    def test_hist(self, param, num_rounds, dataset):
+        param["tree_method"] = "gpu_hist"
+        param = dataset.set_params(param)
+        result = train_result(param, dataset.get_dmat(), num_rounds)
+        note(result)
+        assert tm.non_increasing(result["train"][dataset.metric])
+
+
 class TestGPUUpdaters:
     cputest = test_up.TestTreeMethod()

@@ -101,7 +114,7 @@ class TestGPUUpdaters:
     ) -> None:
         cat_parameters.update(hist_parameters)
         dataset = tm.TestDataset(
-            "ames_housing", tm.get_ames_housing, "reg:squarederror", "rmse"
+            "ames_housing", tm.data.get_ames_housing, "reg:squarederror", "rmse"
         )
         cat_parameters["tree_method"] = "gpu_hist"
         results = train_result(cat_parameters, dataset.get_dmat(), 16)

@@ -15,13 +15,17 @@ rng = np.random.RandomState(1994)


 def json_model(model_path: str, parameters: dict) -> dict:
-    X = np.random.random((10, 3))
-    y = np.random.randint(2, size=(10,))
+    datasets = pytest.importorskip("sklearn.datasets")
+
+    X, y = datasets.make_classification(64, n_features=8, n_classes=3, n_informative=6)
+    if parameters.get("objective", None) == "multi:softmax":
+        parameters["num_class"] = 3
+
     dm1 = xgb.DMatrix(X, y)

     bst = xgb.train(parameters, dm1)
     bst.save_model(model_path)

     if model_path.endswith("ubj"):
         import ubjson
         with open(model_path, "rb") as ubjfd:
@@ -326,24 +330,43 @@ class TestModels:
         from_ubjraw = xgb.Booster()
         from_ubjraw.load_model(ubj_raw)

-        old_from_json = from_jraw.save_raw(raw_format="deprecated")
-        old_from_ubj = from_ubjraw.save_raw(raw_format="deprecated")
-
-        assert old_from_json == old_from_ubj
+        if parameters.get("multi_strategy", None) != "multi_output_tree":
+            # old binary model is not supported.
+            old_from_json = from_jraw.save_raw(raw_format="deprecated")
+            old_from_ubj = from_ubjraw.save_raw(raw_format="deprecated")
+
+            assert old_from_json == old_from_ubj

         raw_json = bst.save_raw(raw_format="json")
         pretty = json.dumps(json.loads(raw_json), indent=2) + "\n\n"
         bst.load_model(bytearray(pretty, encoding="ascii"))

-        old_from_json = from_jraw.save_raw(raw_format="deprecated")
-        old_from_ubj = from_ubjraw.save_raw(raw_format="deprecated")
-
-        assert old_from_json == old_from_ubj
+        if parameters.get("multi_strategy", None) != "multi_output_tree":
+            # old binary model is not supported.
+            old_from_json = from_jraw.save_raw(raw_format="deprecated")
+            old_from_ubj = from_ubjraw.save_raw(raw_format="deprecated")
+
+            assert old_from_json == old_from_ubj
+
+        rng = np.random.default_rng()
+        X = rng.random(size=from_jraw.num_features() * 10).reshape(
+            (10, from_jraw.num_features())
+        )
+        predt_from_jraw = from_jraw.predict(xgb.DMatrix(X))
+        predt_from_bst = bst.predict(xgb.DMatrix(X))
+        np.testing.assert_allclose(predt_from_jraw, predt_from_bst)

     @pytest.mark.parametrize("ext", ["json", "ubj"])
     def test_model_json_io(self, ext: str) -> None:
         parameters = {"booster": "gbtree", "tree_method": "hist"}
         self.run_model_json_io(parameters, ext)
+        parameters = {
+            "booster": "gbtree",
+            "tree_method": "hist",
+            "multi_strategy": "multi_output_tree",
+            "objective": "multi:softmax",
+        }
+        self.run_model_json_io(parameters, ext)
         parameters = {"booster": "gblinear"}
         self.run_model_json_io(parameters, ext)
         parameters = {"booster": "dart", "tree_method": "hist"}

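The round-trip test above now also exercises the experimental strategy. A condensed stand-alone version of the same flow, with parameter values taken from the test and data sizes invented for brevity:

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(1994)
    X = rng.random((64, 8))
    y = rng.integers(0, 3, size=64)
    params = {
        "booster": "gbtree",
        "tree_method": "hist",
        "multi_strategy": "multi_output_tree",
        "objective": "multi:softmax",
        "num_class": 3,
    }
    bst = xgb.train(params, xgb.DMatrix(X, y), num_boost_round=4)

    raw = bst.save_raw(raw_format="json")  # bytearray holding the JSON model
    clone = xgb.Booster()
    clone.load_model(raw)
    np.testing.assert_allclose(
        clone.predict(xgb.DMatrix(X)), bst.predict(xgb.DMatrix(X))
    )
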
@@ -465,7 +465,7 @@ class TestCallbacks:
         assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".pkl"))

     def test_callback_list(self):
-        X, y = tm.get_california_housing()
+        X, y = tm.data.get_california_housing()
         m = xgb.DMatrix(X, y)
         callbacks = [xgb.callback.EarlyStopping(rounds=10)]
         for i in range(4):

@@ -82,7 +82,7 @@ class TestRanking:
         """
         cls.dpath = 'demo/rank/'
         (x_train, y_train, qid_train, x_test, y_test, qid_test,
-         x_valid, y_valid, qid_valid) = tm.get_mq2008(cls.dpath)
+         x_valid, y_valid, qid_valid) = tm.data.get_mq2008(cls.dpath)

         # instantiate the matrices
         cls.dtrain = xgboost.DMatrix(x_train, y_train)

@@ -11,6 +11,7 @@ from xgboost import testing as tm
 from xgboost.testing.params import (
     cat_parameter_strategy,
     exact_parameter_strategy,
+    hist_multi_parameter_strategy,
     hist_parameter_strategy,
 )
 from xgboost.testing.updater import check_init_estimation, check_quantile_loss
@@ -18,11 +19,70 @@ from xgboost.testing.updater import check_init_estimation, check_quantile_loss

 def train_result(param, dmat, num_rounds):
     result = {}
-    xgb.train(param, dmat, num_rounds, [(dmat, 'train')], verbose_eval=False,
-              evals_result=result)
+    booster = xgb.train(
+        param,
+        dmat,
+        num_rounds,
+        [(dmat, "train")],
+        verbose_eval=False,
+        evals_result=result,
+    )
+    assert booster.num_features() == dmat.num_col()
+    assert booster.num_boosted_rounds() == num_rounds
+    assert booster.feature_names == dmat.feature_names
+    assert booster.feature_types == dmat.feature_types
+
     return result


+class TestTreeMethodMulti:
+    @given(
+        exact_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
+    )
+    @settings(deadline=None, print_blob=True)
+    def test_exact(self, param: dict, num_rounds: int, dataset: tm.TestDataset) -> None:
+        if dataset.name.endswith("-l1"):
+            return
+        param["tree_method"] = "exact"
+        param = dataset.set_params(param)
+        result = train_result(param, dataset.get_dmat(), num_rounds)
+        assert tm.non_increasing(result["train"][dataset.metric])
+
+    @given(
+        exact_parameter_strategy,
+        hist_parameter_strategy,
+        strategies.integers(1, 20),
+        tm.multi_dataset_strategy,
+    )
+    @settings(deadline=None, print_blob=True)
+    def test_approx(self, param, hist_param, num_rounds, dataset):
+        param["tree_method"] = "approx"
+        param = dataset.set_params(param)
+        param.update(hist_param)
+        result = train_result(param, dataset.get_dmat(), num_rounds)
+        note(result)
+        assert tm.non_increasing(result["train"][dataset.metric])
+
+    @given(
+        exact_parameter_strategy,
+        hist_multi_parameter_strategy,
+        strategies.integers(1, 20),
+        tm.multi_dataset_strategy,
+    )
+    @settings(deadline=None, print_blob=True)
+    def test_hist(
+        self, param: dict, hist_param: dict, num_rounds: int, dataset: tm.TestDataset
+    ) -> None:
+        if dataset.name.endswith("-l1"):
+            return
+        param["tree_method"] = "hist"
+        param = dataset.set_params(param)
+        param.update(hist_param)
+        result = train_result(param, dataset.get_dmat(), num_rounds)
+        note(result)
+        assert tm.non_increasing(result["train"][dataset.metric])
+
+
 class TestTreeMethod:
     USE_ONEHOT = np.iinfo(np.int32).max
     USE_PART = 1
@@ -77,10 +137,14 @@ class TestTreeMethod:
         # Second prune should not change the tree
         assert after_prune == second_prune

-    @given(exact_parameter_strategy, hist_parameter_strategy, strategies.integers(1, 20),
-           tm.dataset_strategy)
+    @given(
+        exact_parameter_strategy,
+        hist_parameter_strategy,
+        strategies.integers(1, 20),
+        tm.dataset_strategy
+    )
     @settings(deadline=None, print_blob=True)
-    def test_hist(self, param, hist_param, num_rounds, dataset):
+    def test_hist(self, param: dict, hist_param: dict, num_rounds: int, dataset: tm.TestDataset) -> None:
         param['tree_method'] = 'hist'
         param = dataset.set_params(param)
         param.update(hist_param)
@@ -88,23 +152,6 @@ class TestTreeMethod:
         note(result)
         assert tm.non_increasing(result['train'][dataset.metric])

-    @given(tm.sparse_datasets_strategy)
-    @settings(deadline=None, print_blob=True)
-    def test_sparse(self, dataset):
-        param = {"tree_method": "hist", "max_bin": 64}
-        hist_result = train_result(param, dataset.get_dmat(), 16)
-        note(hist_result)
-        assert tm.non_increasing(hist_result['train'][dataset.metric])
-
-        param = {"tree_method": "approx", "max_bin": 64}
-        approx_result = train_result(param, dataset.get_dmat(), 16)
-        note(approx_result)
-        assert tm.non_increasing(approx_result['train'][dataset.metric])
-
-        np.testing.assert_allclose(
-            hist_result["train"]["rmse"], approx_result["train"]["rmse"]
-        )
-
     def test_hist_categorical(self):
         # hist must be same as exact on all-categorial data
         dpath = 'demo/data/'
@@ -143,6 +190,23 @@ class TestTreeMethod:
         w = [0, 0, 1, 0]
         model.fit(X, y, sample_weight=w)

+    @given(tm.sparse_datasets_strategy)
+    @settings(deadline=None, print_blob=True)
+    def test_sparse(self, dataset):
+        param = {"tree_method": "hist", "max_bin": 64}
+        hist_result = train_result(param, dataset.get_dmat(), 16)
+        note(hist_result)
+        assert tm.non_increasing(hist_result['train'][dataset.metric])
+
+        param = {"tree_method": "approx", "max_bin": 64}
+        approx_result = train_result(param, dataset.get_dmat(), 16)
+        note(approx_result)
+        assert tm.non_increasing(approx_result['train'][dataset.metric])
+
+        np.testing.assert_allclose(
+            hist_result["train"]["rmse"], approx_result["train"]["rmse"]
+        )
+
     def run_invalid_category(self, tree_method: str) -> None:
         rng = np.random.default_rng()
         # too large
@@ -365,7 +429,7 @@ class TestTreeMethod:
     ) -> None:
         cat_parameters.update(hist_parameters)
         dataset = tm.TestDataset(
-            "ames_housing", tm.get_ames_housing, "reg:squarederror", "rmse"
+            "ames_housing", tm.data.get_ames_housing, "reg:squarederror", "rmse"
         )
         cat_parameters["tree_method"] = tree_method
         results = train_result(cat_parameters, dataset.get_dmat(), 16)

@@ -1168,7 +1168,7 @@ def test_dask_aft_survival() -> None:

 def test_dask_ranking(client: "Client") -> None:
     dpath = "demo/rank/"
-    mq2008 = tm.get_mq2008(dpath)
+    mq2008 = tm.data.get_mq2008(dpath)
     data = []
     for d in mq2008:
         if isinstance(d, scipy.sparse.csr_matrix):

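All the `tm.get_*` call sites touched in this commit move to the relocated fetchers in `xgboost.testing.data`; only the attribute path changes:

    from xgboost import testing as tm

    # Before: tm.get_mq2008("demo/rank/")
    x_train, y_train, qid_train, *rest = tm.data.get_mq2008("demo/rank/")
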