Initial support for multi-target tree. (#8616)

* Implement multi-target for hist.

- Add new hist tree builder.
- Move data fetchers for tests.
- Dispatch function calls in gbm based on the tree type.
Jiaming Yuan 2023-03-22 23:49:56 +08:00 committed by GitHub
parent ea04d4c46c
commit 151882dd26
34 changed files with 856 additions and 389 deletions

View File

@ -7,6 +7,12 @@ The demo is adopted from scikit-learn:
https://scikit-learn.org/stable/auto_examples/ensemble/plot_random_forest_regression_multioutput.html#sphx-glr-auto-examples-ensemble-plot-random-forest-regression-multioutput-py
See :doc:`/tutorials/multioutput` for more information.
.. note::
The feature is experimental. For the `multi_output_tree` strategy, many features are
missing.
"""
import argparse
@ -40,11 +46,18 @@ def gen_circle() -> Tuple[np.ndarray, np.ndarray]:
return X, y
def rmse_model(plot_result: bool):
def rmse_model(plot_result: bool, strategy: str):
"""Draw a circle with 2-dim coordinate as target variables."""
X, y = gen_circle()
# Train a regressor on it
reg = xgb.XGBRegressor(tree_method="hist", n_estimators=64)
reg = xgb.XGBRegressor(
tree_method="hist",
n_estimators=128,
n_jobs=16,
max_depth=8,
multi_strategy=strategy,
subsample=0.6,
)
reg.fit(X, y, eval_set=[(X, y)])
y_predt = reg.predict(X)
@ -52,7 +65,7 @@ def rmse_model(plot_result: bool):
plot_predt(y, y_predt, "multi")
def custom_rmse_model(plot_result: bool) -> None:
def custom_rmse_model(plot_result: bool, strategy: str) -> None:
"""Train using Python implementation of Squared Error."""
# Due to the experimental support status, custom objective doesn't support matrix as
@ -88,9 +101,10 @@ def custom_rmse_model(plot_result: bool) -> None:
{
"tree_method": "hist",
"num_target": y.shape[1],
"multi_strategy": strategy,
},
dtrain=Xy,
num_boost_round=100,
num_boost_round=128,
obj=squared_log,
evals=[(Xy, "Train")],
evals_result=results,
@ -107,6 +121,16 @@ if __name__ == "__main__":
parser.add_argument("--plot", choices=[0, 1], type=int, default=1)
args = parser.parse_args()
# Train with builtin RMSE objective
rmse_model(args.plot == 1)
# - One model per output.
rmse_model(args.plot == 1, "one_output_per_tree")
# - One model for all outputs. This is still a work in progress; many features are
# missing.
rmse_model(args.plot == 1, "multi_output_tree")
# Train with custom objective.
custom_rmse_model(args.plot == 1)
# - One model per output.
custom_rmse_model(args.plot == 1, "one_output_per_tree")
# - One model for all outputs. This is still a work in progress; many features are
# missing.
custom_rmse_model(args.plot == 1, "multi_output_tree")

View File

@ -226,6 +226,18 @@ Parameters for Tree Booster
list is a group of indices of features that are allowed to interact with each other.
See :doc:`/tutorials/feature_interaction_constraint` for more information.
* ``multi_strategy``, [default = ``one_output_per_tree``]
.. versionadded:: 2.0.0
.. note:: This parameter is a work in progress.
- The strategy used for training multi-target models, including multi-target regression
and multi-class classification. See :doc:`/tutorials/multioutput` for more information.
- ``one_output_per_tree``: One model for each target.
- ``multi_output_tree``: Use multi-target trees.
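As a minimal, illustrative sketch (the data and round count below are made up), the
parameter can be passed through the native training interface together with the ``hist``
tree method:

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(128, 4)
    y = np.random.rand(128, 3)  # three targets
    Xy = xgb.DMatrix(X, label=y)
    booster = xgb.train(
        {"tree_method": "hist", "multi_strategy": "multi_output_tree"},
        Xy,
        num_boost_round=8,
    )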
.. _cat-param:
Parameters for Categorical Feature

View File

@ -11,7 +11,11 @@ can be simultaneously classified as both sci-fi and comedy. For detailed explan
terminologies related to different multi-output models please refer to the
:doc:`scikit-learn user guide <sklearn:modules/multiclass>`.
Internally, XGBoost builds one model for each target similar to sklearn meta estimators,
**********************************
Training with One-Model-Per-Target
**********************************
By default, XGBoost builds one model for each target, similar to sklearn meta estimators,
with the added benefit of reusing data and other integrated features like SHAP. For a
worked example of regression, see
:ref:`sphx_glr_python_examples_multioutput_regression.py`. For multi-label classification,
@ -36,3 +40,26 @@ dense matrix for labels.
The feature is still under development with limited support from objectives and metrics.
*************************
Training with Vector Leaf
*************************
.. versionadded:: 2.0
.. note::
This is still a work in progress, and many features are missing.
XGBoost can optionally build multi-output trees, with the size of each leaf equal to the
number of targets, when the tree method `hist` is used. The behavior is controlled by the
``multi_strategy`` training parameter, which can take the value `one_output_per_tree` (the
default) for building one model per target, or `multi_output_tree` for building
multi-output trees.
.. code-block:: python
clf = xgb.XGBClassifier(tree_method="hist", multi_strategy="multi_output_tree")
See :ref:`sphx_glr_python_examples_multioutput_regression.py` for a worked example with
regression.
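As a rough sketch (synthetic data, not taken from the demo), the scikit-learn interface
accepts the same parameter, and predictions keep the ``(n_samples, n_targets)`` shape in
both strategies:

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    X = np.random.rand(256, 8)
    y = np.random.rand(256, 2)  # two regression targets
    reg = xgb.XGBRegressor(tree_method="hist", multi_strategy="multi_output_tree")
    reg.fit(X, y)
    assert reg.predict(X).shape == (256, 2)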

View File

@ -286,8 +286,8 @@ struct LearnerModelParamLegacy;
* \brief Strategy for building multi-target models.
*/
enum class MultiStrategy : std::int32_t {
kComposite = 0,
kMonolithic = 1,
kOneOutputPerTree = 0,
kMultiOutputTree = 1,
};
/**
@ -317,7 +317,7 @@ struct LearnerModelParam {
/**
* \brief Strategy for building multi-target models.
*/
MultiStrategy multi_strategy{MultiStrategy::kComposite};
MultiStrategy multi_strategy{MultiStrategy::kOneOutputPerTree};
LearnerModelParam() = default;
// As the old `LearnerModelParamLegacy` is still used by binary IO, we keep
@ -338,7 +338,7 @@ struct LearnerModelParam {
void Copy(LearnerModelParam const& that);
[[nodiscard]] bool IsVectorLeaf() const noexcept {
return multi_strategy == MultiStrategy::kMonolithic;
return multi_strategy == MultiStrategy::kMultiOutputTree;
}
[[nodiscard]] bst_target_t OutputLength() const noexcept { return this->num_output_group; }
[[nodiscard]] bst_target_t LeafLength() const noexcept {

View File

@ -530,17 +530,17 @@ class TensorView {
/**
* \brief Number of items in the tensor.
*/
LINALG_HD [[nodiscard]] std::size_t Size() const { return size_; }
[[nodiscard]] LINALG_HD std::size_t Size() const { return size_; }
/**
* \brief Whether this is a contiguous array, both C and F contiguous returns true.
*/
LINALG_HD [[nodiscard]] bool Contiguous() const {
[[nodiscard]] LINALG_HD bool Contiguous() const {
return data_.size() == this->Size() || this->CContiguous() || this->FContiguous();
}
/**
* \brief Whether it's a c-contiguous array.
*/
LINALG_HD [[nodiscard]] bool CContiguous() const {
[[nodiscard]] LINALG_HD bool CContiguous() const {
StrideT stride;
static_assert(std::is_same<decltype(stride), decltype(stride_)>::value);
// It's contiguous if the stride can be calculated from shape.
@ -550,7 +550,7 @@ class TensorView {
/**
* \brief Whether it's a f-contiguous array.
*/
LINALG_HD [[nodiscard]] bool FContiguous() const {
[[nodiscard]] LINALG_HD bool FContiguous() const {
StrideT stride;
static_assert(std::is_same<decltype(stride), decltype(stride_)>::value);
// It's contiguous if the stride can be calculated from shape.

View File

@ -312,6 +312,19 @@ __model_doc = f"""
needs to be set to have categorical feature support. See :doc:`Categorical Data
</tutorials/categorical>` and :ref:`cat-param` for details.
multi_strategy : Optional[str]
.. versionadded:: 2.0.0
.. note:: This parameter is a work in progress.
The strategy used for training multi-target models, including multi-target
regression and multi-class classification. See :doc:`/tutorials/multioutput` for
more information.
- ``one_output_per_tree``: One model for each target.
- ``multi_output_tree``: Use multi-target trees.
eval_metric : Optional[Union[str, List[str], Callable]]
.. versionadded:: 1.6.0
@ -624,6 +637,7 @@ class XGBModel(XGBModelBase):
feature_types: Optional[FeatureTypes] = None,
max_cat_to_onehot: Optional[int] = None,
max_cat_threshold: Optional[int] = None,
multi_strategy: Optional[str] = None,
eval_metric: Optional[Union[str, List[str], Callable]] = None,
early_stopping_rounds: Optional[int] = None,
callbacks: Optional[List[TrainingCallback]] = None,
@ -670,6 +684,7 @@ class XGBModel(XGBModelBase):
self.feature_types = feature_types
self.max_cat_to_onehot = max_cat_to_onehot
self.max_cat_threshold = max_cat_threshold
self.multi_strategy = multi_strategy
self.eval_metric = eval_metric
self.early_stopping_rounds = early_stopping_rounds
self.callbacks = callbacks
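For illustration only (not part of the module), the new keyword is forwarded to the
underlying booster like the other constructor arguments; a hedged multi-label sketch with
synthetic data:

# Editorial sketch: synthetic multi-label data, parameter values are illustrative.
import numpy as np
import xgboost as xgb

X = np.random.rand(128, 10)
y = np.random.randint(0, 2, size=(128, 4))  # four binary labels
clf = xgb.XGBClassifier(tree_method="hist", multi_strategy="multi_output_tree")
clf.fit(X, y)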

View File

@ -10,11 +10,9 @@ import os
import platform
import socket
import sys
import zipfile
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from io import StringIO
from pathlib import Path
from platform import system
from typing import (
Any,
@ -29,7 +27,6 @@ from typing import (
TypedDict,
Union,
)
from urllib import request
import numpy as np
import pytest
@ -38,6 +35,13 @@ from scipy import sparse
import xgboost as xgb
from xgboost.core import ArrayLike
from xgboost.sklearn import SklObjective
from xgboost.testing.data import (
get_california_housing,
get_cancer,
get_digits,
get_sparse,
memory,
)
hypothesis = pytest.importorskip("hypothesis")
@ -45,13 +49,8 @@ hypothesis = pytest.importorskip("hypothesis")
from hypothesis import strategies
from hypothesis.extra.numpy import arrays
joblib = pytest.importorskip("joblib")
datasets = pytest.importorskip("sklearn.datasets")
Memory = joblib.Memory
memory = Memory("./cachedir", verbose=0)
PytestSkip = TypedDict("PytestSkip", {"condition": bool, "reason": str})
@ -353,137 +352,6 @@ class TestDataset:
return self.name
@memory.cache
def get_california_housing() -> Tuple[np.ndarray, np.ndarray]:
data = datasets.fetch_california_housing()
return data.data, data.target
@memory.cache
def get_digits() -> Tuple[np.ndarray, np.ndarray]:
data = datasets.load_digits()
return data.data, data.target
@memory.cache
def get_cancer() -> Tuple[np.ndarray, np.ndarray]:
return datasets.load_breast_cancer(return_X_y=True)
@memory.cache
def get_sparse() -> Tuple[np.ndarray, np.ndarray]:
rng = np.random.RandomState(199)
n = 2000
sparsity = 0.75
X, y = datasets.make_regression(n, random_state=rng)
flag = rng.binomial(1, sparsity, X.shape)
for i in range(X.shape[0]):
for j in range(X.shape[1]):
if flag[i, j]:
X[i, j] = np.nan
return X, y
@memory.cache
def get_ames_housing() -> Tuple[np.ndarray, np.ndarray]:
"""
Number of samples: 1460
Number of features: 20
Number of categorical features: 10
Number of numerical features: 10
"""
from sklearn.datasets import fetch_openml
X, y = fetch_openml(data_id=42165, as_frame=True, return_X_y=True)
categorical_columns_subset: List[str] = [
"BldgType", # 5 cats, no nan
"GarageFinish", # 3 cats, nan
"LotConfig", # 5 cats, no nan
"Functional", # 7 cats, no nan
"MasVnrType", # 4 cats, nan
"HouseStyle", # 8 cats, no nan
"FireplaceQu", # 5 cats, nan
"ExterCond", # 5 cats, no nan
"ExterQual", # 4 cats, no nan
"PoolQC", # 3 cats, nan
]
numerical_columns_subset: List[str] = [
"3SsnPorch",
"Fireplaces",
"BsmtHalfBath",
"HalfBath",
"GarageCars",
"TotRmsAbvGrd",
"BsmtFinSF1",
"BsmtFinSF2",
"GrLivArea",
"ScreenPorch",
]
X = X[categorical_columns_subset + numerical_columns_subset]
X[categorical_columns_subset] = X[categorical_columns_subset].astype("category")
return X, y
@memory.cache
def get_mq2008(
dpath: str,
) -> Tuple[
sparse.csr_matrix,
np.ndarray,
np.ndarray,
sparse.csr_matrix,
np.ndarray,
np.ndarray,
sparse.csr_matrix,
np.ndarray,
np.ndarray,
]:
from sklearn.datasets import load_svmlight_files
src = "https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip"
target = os.path.join(os.path.expanduser(dpath), "MQ2008.zip")
if not os.path.exists(target):
request.urlretrieve(url=src, filename=target)
with zipfile.ZipFile(target, "r") as f:
f.extractall(path=dpath)
(
x_train,
y_train,
qid_train,
x_test,
y_test,
qid_test,
x_valid,
y_valid,
qid_valid,
) = load_svmlight_files(
(
Path(dpath) / "MQ2008" / "Fold1" / "train.txt",
Path(dpath) / "MQ2008" / "Fold1" / "test.txt",
Path(dpath) / "MQ2008" / "Fold1" / "vali.txt",
),
query_id=True,
zero_based=False,
)
return (
x_train,
y_train,
qid_train,
x_test,
y_test,
qid_test,
x_valid,
y_valid,
qid_valid,
)
# pylint: disable=too-many-arguments,too-many-locals
@memory.cache
def make_categorical(
@ -738,20 +606,7 @@ _unweighted_datasets_strategy = strategies.sampled_from(
TestDataset(
"calif_housing-l1", get_california_housing, "reg:absoluteerror", "mae"
),
TestDataset("digits", get_digits, "multi:softmax", "mlogloss"),
TestDataset("cancer", get_cancer, "binary:logistic", "logloss"),
TestDataset(
"mtreg",
lambda: datasets.make_regression(n_samples=128, n_features=2, n_targets=3),
"reg:squarederror",
"rmse",
),
TestDataset(
"mtreg-l1",
lambda: datasets.make_regression(n_samples=128, n_features=2, n_targets=3),
"reg:absoluteerror",
"mae",
),
TestDataset("sparse", get_sparse, "reg:squarederror", "rmse"),
TestDataset("sparse-l1", get_sparse, "reg:absoluteerror", "mae"),
TestDataset(
@ -764,9 +619,17 @@ _unweighted_datasets_strategy = strategies.sampled_from(
)
def make_datasets_with_margin(
unweighted_strategy: strategies.SearchStrategy,
) -> Callable:
"""Factory function for creating strategies that generates datasets with weight and
base margin.
"""
@strategies.composite
def _dataset_weight_margin(draw: Callable) -> TestDataset:
data: TestDataset = draw(_unweighted_datasets_strategy)
def weight_margin(draw: Callable) -> TestDataset:
data: TestDataset = draw(unweighted_strategy)
if draw(strategies.booleans()):
data.w = draw(
arrays(np.float64, (len(data.y)), elements=strategies.floats(0.1, 2.0))
@ -791,10 +654,36 @@ def _dataset_weight_margin(draw: Callable) -> TestDataset:
return data
return weight_margin
# A strategy for drawing from a set of example datasets
# May add random weights to the dataset
dataset_strategy = _dataset_weight_margin()
# A strategy for drawing from a set of example datasets. May add random weights to the
# dataset
dataset_strategy = make_datasets_with_margin(_unweighted_datasets_strategy)()
_unweighted_multi_datasets_strategy = strategies.sampled_from(
[
TestDataset("digits", get_digits, "multi:softmax", "mlogloss"),
TestDataset(
"mtreg",
lambda: datasets.make_regression(n_samples=128, n_features=2, n_targets=3),
"reg:squarederror",
"rmse",
),
TestDataset(
"mtreg-l1",
lambda: datasets.make_regression(n_samples=128, n_features=2, n_targets=3),
"reg:absoluteerror",
"mae",
),
]
)
# A strategy for drawing from a set of multi-target/multi-class datasets.
multi_dataset_strategy = make_datasets_with_margin(
_unweighted_multi_datasets_strategy
)()
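As a hedged sketch of how the new strategy is meant to be consumed (mirroring the updater
tests in this commit; ``hist_parameter_strategy`` and ``train_result`` live in the test
modules and are assumed to be imported, along with hypothesis's ``given``/``settings``):

# Editorial sketch, not part of the module.
@given(hist_parameter_strategy, strategies.integers(1, 20), multi_dataset_strategy)
@settings(deadline=None, max_examples=20, print_blob=True)
def test_multi_hist(param, num_rounds, dataset):
    param["tree_method"] = "hist"
    param = dataset.set_params(param)
    result = train_result(param, dataset.get_dmat(), num_rounds)
    assert non_increasing(result["train"][dataset.metric])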
def non_increasing(L: Sequence[float], tolerance: float = 1e-4) -> bool:

View File

@ -1,13 +1,20 @@
"""Utilities for data generation."""
from typing import Any, Generator, Tuple, Union
import os
import zipfile
from typing import Any, Generator, List, Tuple, Union
from urllib import request
import numpy as np
import pytest
from numpy.random import Generator as RNG
from scipy import sparse
import xgboost
from xgboost.data import pandas_pyarrow_mapper
joblib = pytest.importorskip("joblib")
memory = joblib.Memory("./cachedir", verbose=0)
def np_dtypes(
n_samples: int, n_features: int
@ -195,3 +202,141 @@ def check_inf(rng: RNG) -> None:
with pytest.raises(ValueError, match="Input data contains `inf`"):
xgboost.DMatrix(X, y)
@memory.cache
def get_california_housing() -> Tuple[np.ndarray, np.ndarray]:
"""Fetch the California housing dataset from sklearn."""
datasets = pytest.importorskip("sklearn.datasets")
data = datasets.fetch_california_housing()
return data.data, data.target
@memory.cache
def get_digits() -> Tuple[np.ndarray, np.ndarray]:
"""Fetch the digits dataset from sklearn."""
datasets = pytest.importorskip("sklearn.datasets")
data = datasets.load_digits()
return data.data, data.target
@memory.cache
def get_cancer() -> Tuple[np.ndarray, np.ndarray]:
"""Fetch the breast cancer dataset from sklearn."""
datasets = pytest.importorskip("sklearn.datasets")
return datasets.load_breast_cancer(return_X_y=True)
@memory.cache
def get_sparse() -> Tuple[np.ndarray, np.ndarray]:
"""Generate a sparse dataset."""
datasets = pytest.importorskip("sklearn.datasets")
rng = np.random.RandomState(199)
n = 2000
sparsity = 0.75
X, y = datasets.make_regression(n, random_state=rng)
flag = rng.binomial(1, sparsity, X.shape)
for i in range(X.shape[0]):
for j in range(X.shape[1]):
if flag[i, j]:
X[i, j] = np.nan
return X, y
@memory.cache
def get_ames_housing() -> Tuple[np.ndarray, np.ndarray]:
"""
Number of samples: 1460
Number of features: 20
Number of categorical features: 10
Number of numerical features: 10
"""
datasets = pytest.importorskip("sklearn.datasets")
X, y = datasets.fetch_openml(data_id=42165, as_frame=True, return_X_y=True)
categorical_columns_subset: List[str] = [
"BldgType", # 5 cats, no nan
"GarageFinish", # 3 cats, nan
"LotConfig", # 5 cats, no nan
"Functional", # 7 cats, no nan
"MasVnrType", # 4 cats, nan
"HouseStyle", # 8 cats, no nan
"FireplaceQu", # 5 cats, nan
"ExterCond", # 5 cats, no nan
"ExterQual", # 4 cats, no nan
"PoolQC", # 3 cats, nan
]
numerical_columns_subset: List[str] = [
"3SsnPorch",
"Fireplaces",
"BsmtHalfBath",
"HalfBath",
"GarageCars",
"TotRmsAbvGrd",
"BsmtFinSF1",
"BsmtFinSF2",
"GrLivArea",
"ScreenPorch",
]
X = X[categorical_columns_subset + numerical_columns_subset]
X[categorical_columns_subset] = X[categorical_columns_subset].astype("category")
return X, y
@memory.cache
def get_mq2008(
dpath: str,
) -> Tuple[
sparse.csr_matrix,
np.ndarray,
np.ndarray,
sparse.csr_matrix,
np.ndarray,
np.ndarray,
sparse.csr_matrix,
np.ndarray,
np.ndarray,
]:
"""Fetch the mq2008 dataset."""
datasets = pytest.importorskip("sklearn.datasets")
src = "https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip"
target = os.path.join(dpath, "MQ2008.zip")
if not os.path.exists(target):
request.urlretrieve(url=src, filename=target)
with zipfile.ZipFile(target, "r") as f:
f.extractall(path=dpath)
(
x_train,
y_train,
qid_train,
x_test,
y_test,
qid_test,
x_valid,
y_valid,
qid_valid,
) = datasets.load_svmlight_files(
(
os.path.join(dpath, "MQ2008/Fold1/train.txt"),
os.path.join(dpath, "MQ2008/Fold1/test.txt"),
os.path.join(dpath, "MQ2008/Fold1/vali.txt"),
),
query_id=True,
zero_based=False,
)
return (
x_train,
y_train,
qid_train,
x_test,
y_test,
qid_test,
x_valid,
y_valid,
qid_valid,
)
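For illustration, downstream tests can now import the relocated helpers directly from
``xgboost.testing.data``; a small hedged usage sketch (requires scikit-learn):

# Editorial sketch, not part of the module.
from xgboost.testing.data import get_digits, get_sparse

X, y = get_digits()        # cached on disk via joblib.Memory("./cachedir")
X_sp, y_sp = get_sparse()  # regression data with roughly 75% of entries set to NaN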

View File

@ -4,8 +4,8 @@ from typing import cast
import pytest
hypothesis = pytest.importorskip("hypothesis")
from hypothesis import strategies # pylint:disable=wrong-import-position
strategies = pytest.importorskip("hypothesis.strategies")
exact_parameter_strategy = strategies.fixed_dictionaries(
{
@ -41,6 +41,26 @@ hist_parameter_strategy = strategies.fixed_dictionaries(
and (cast(int, x["max_depth"]) > 0 or x["grow_policy"] == "lossguide")
)
hist_multi_parameter_strategy = strategies.fixed_dictionaries(
{
"max_depth": strategies.integers(1, 11),
"max_leaves": strategies.integers(0, 1024),
"max_bin": strategies.integers(2, 512),
"multi_strategy": strategies.sampled_from(
["multi_output_tree", "one_output_per_tree"]
),
"grow_policy": strategies.sampled_from(["lossguide", "depthwise"]),
"min_child_weight": strategies.floats(0.5, 2.0),
# We cannot enable subsampling as the training loss can increase
# 'subsample': strategies.floats(0.5, 1.0),
"colsample_bytree": strategies.floats(0.5, 1.0),
"colsample_bylevel": strategies.floats(0.5, 1.0),
}
).filter(
lambda x: (cast(int, x["max_depth"]) > 0 or cast(int, x["max_leaves"]) > 0)
and (cast(int, x["max_depth"]) > 0 or x["grow_policy"] == "lossguide")
)
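For a quick look at what this strategy generates, a single example can be drawn
interactively (the printed dictionary below is only illustrative, not a fixed output):

# Editorial sketch: `.example()` is a hypothesis convenience for interactive exploration.
print(hist_multi_parameter_strategy.example())
# e.g. {'max_depth': 6, 'max_leaves': 0, 'max_bin': 256,
#       'multi_strategy': 'multi_output_tree', 'grow_policy': 'depthwise',
#       'min_child_weight': 1.0, 'colsample_bytree': 0.8, 'colsample_bylevel': 0.9}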
cat_parameter_strategy = strategies.fixed_dictionaries(
{
"max_cat_to_onehot": strategies.integers(1, 128),

View File

@ -55,6 +55,7 @@ inline void CalcPredictShape(bool strict_shape, PredictionType type, size_t rows
*out_dim = 2;
shape.resize(*out_dim);
shape.front() = rows;
// chunksize can be 1 if it's softmax
shape.back() = std::min(groups, chunksize);
}
break;

View File

@ -359,6 +359,7 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b
HistogramCuts *cuts) {
size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
auto &cut_values = cuts->cut_values_.HostVector();
// we use the min_value as the first (0th) element, hence starting from 1.
for (size_t i = 1; i < required_cuts; ++i) {
bst_float cpt = summary.data[i].value;
if (i == 1 || cpt > cut_values.back()) {
@ -419,8 +420,8 @@ void SketchContainerImpl<WQSketch>::MakeCuts(HistogramCuts* cuts) {
} else {
AddCutPoint<WQSketch>(a, max_num_bins, cuts);
// push a value that is greater than anything
const bst_float cpt = (a.size > 0) ? a.data[a.size - 1].value
: cuts->min_vals_.HostVector()[fid];
const bst_float cpt =
(a.size > 0) ? a.data[a.size - 1].value : cuts->min_vals_.HostVector()[fid];
// this must be bigger than last value in a scale
const bst_float last = cpt + (fabs(cpt) + 1e-5f);
cuts->cut_values_.HostVector().push_back(last);

View File

@ -352,19 +352,6 @@ struct WQSummary {
prev_rmax = data[i].rmax;
}
}
// check consistency of the summary
inline bool Check(const char *msg) const {
const float tol = 10.0f;
for (size_t i = 0; i < this->size; ++i) {
if (data[i].rmin + data[i].wmin > data[i].rmax + tol ||
data[i].rmin < -1e-6f || data[i].rmax < -1e-6f) {
LOG(INFO) << "---------- WQSummary::Check did not pass ----------";
this->Print();
return false;
}
}
return true;
}
};
/*! \brief try to do efficient pruning */

View File

@ -257,6 +257,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
}
iter.Reset();
CHECK_EQ(rbegin, Info().num_row_);
CHECK_EQ(this->ghist_->Features(), Info().num_col_);
/**
* Generate column matrix

View File

@ -10,6 +10,7 @@
#include <dmlc/parameter.h>
#include <algorithm>
#include <cinttypes> // for uint32_t
#include <limits>
#include <memory>
#include <string>
@ -27,9 +28,11 @@
#include "xgboost/host_device_vector.h"
#include "xgboost/json.h"
#include "xgboost/logging.h"
#include "xgboost/model.h"
#include "xgboost/objective.h"
#include "xgboost/predictor.h"
#include "xgboost/string_view.h"
#include "xgboost/string_view.h" // for StringView
#include "xgboost/tree_model.h" // for RegTree
#include "xgboost/tree_updater.h"
namespace xgboost::gbm {
@ -131,6 +134,12 @@ void GBTree::PerformTreeMethodHeuristic(DMatrix* fmat) {
// set, since only experts are expected to do so.
return;
}
if (model_.learner_model_param->IsVectorLeaf()) {
CHECK(tparam_.tree_method == TreeMethod::kHist)
<< "Only the hist tree method is supported for building multi-target trees with vector "
"leaf.";
}
// tparam_ is set before calling this function.
if (tparam_.tree_method != TreeMethod::kAuto) {
return;
@ -175,12 +184,12 @@ void GBTree::ConfigureUpdaters() {
case TreeMethod::kExact:
tparam_.updater_seq = "grow_colmaker,prune";
break;
case TreeMethod::kHist:
LOG(INFO) <<
"Tree method is selected to be 'hist', which uses a "
"single updater grow_quantile_histmaker.";
case TreeMethod::kHist: {
LOG(INFO) << "Tree method is selected to be 'hist', which uses a single updater "
"grow_quantile_histmaker.";
tparam_.updater_seq = "grow_quantile_histmaker";
break;
}
case TreeMethod::kGPUHist: {
common::AssertGPUSupport();
tparam_.updater_seq = "grow_gpu_hist";
@ -209,11 +218,9 @@ void CopyGradient(HostDeviceVector<GradientPair> const* in_gpair, int32_t n_thre
GPUCopyGradient(in_gpair, n_groups, group_id, out_gpair);
} else {
std::vector<GradientPair> &tmp_h = out_gpair->HostVector();
auto nsize = static_cast<bst_omp_uint>(out_gpair->Size());
const auto& gpair_h = in_gpair->ConstHostVector();
common::ParallelFor(nsize, n_threads, [&](bst_omp_uint i) {
tmp_h[i] = gpair_h[i * n_groups + group_id];
});
common::ParallelFor(out_gpair->Size(), n_threads,
[&](auto i) { tmp_h[i] = gpair_h[i * n_groups + group_id]; });
}
}
@ -234,6 +241,7 @@ void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector<float> const
CHECK_EQ(model_.param.num_parallel_tree, trees.size());
CHECK_EQ(model_.param.num_parallel_tree, 1)
<< "Boosting random forest is not supported for current objective.";
CHECK(!trees.front()->IsMultiTarget()) << "Update tree leaf" << MTNotImplemented();
CHECK_EQ(trees.size(), model_.param.num_parallel_tree);
for (std::size_t tree_idx = 0; tree_idx < trees.size(); ++tree_idx) {
auto const& position = node_position.at(tree_idx);
@ -245,17 +253,18 @@ void GBTree::UpdateTreeLeaf(DMatrix const* p_fmat, HostDeviceVector<float> const
void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
PredictionCacheEntry* predt, ObjFunction const* obj) {
std::vector<std::vector<std::unique_ptr<RegTree>>> new_trees;
const int ngroup = model_.learner_model_param->num_output_group;
const int ngroup = model_.learner_model_param->OutputLength();
ConfigureWithKnownData(this->cfg_, p_fmat);
monitor_.Start("BoostNewTrees");
// Weird case where the tree method is CPU-based but gpu_id is set. Ideally we should let
// `gpu_id` be the single source for determining what algorithms to run, but that will
// break a lot of existing code.
auto device = tparam_.tree_method != TreeMethod::kGPUHist ? Context::kCpuId : ctx_->gpu_id;
auto out = linalg::TensorView<float, 2>{
auto out = linalg::MakeTensorView(
device,
device == Context::kCpuId ? predt->predictions.HostSpan() : predt->predictions.DeviceSpan(),
{static_cast<size_t>(p_fmat->Info().num_row_), static_cast<size_t>(ngroup)},
device};
p_fmat->Info().num_row_, model_.learner_model_param->OutputLength());
CHECK_NE(ngroup, 0);
if (!p_fmat->SingleColBlock() && obj->Task().UpdateTreeLeaf()) {
@ -266,7 +275,13 @@ void GBTree::DoBoost(DMatrix* p_fmat, HostDeviceVector<GradientPair>* in_gpair,
// position is negated if the row is sampled out.
std::vector<HostDeviceVector<bst_node_t>> node_position;
if (ngroup == 1) {
if (model_.learner_model_param->IsVectorLeaf()) {
std::vector<std::unique_ptr<RegTree>> ret;
BoostNewTrees(in_gpair, p_fmat, 0, &node_position, &ret);
UpdateTreeLeaf(p_fmat, predt->predictions, obj, 0, node_position, &ret);
// The prediction cache is not updated yet.
new_trees.push_back(std::move(ret));
} else if (model_.learner_model_param->OutputLength() == 1) {
std::vector<std::unique_ptr<RegTree>> ret;
BoostNewTrees(in_gpair, p_fmat, 0, &node_position, &ret);
UpdateTreeLeaf(p_fmat, predt->predictions, obj, 0, node_position, &ret);
@ -383,11 +398,15 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fma
}
// update the trees
CHECK_EQ(gpair->Size(), p_fmat->Info().num_row_)
<< "Mismatching size between number of rows from input data and size of "
"gradient vector.";
auto n_out = model_.learner_model_param->OutputLength() * p_fmat->Info().num_row_;
StringView msg{
"Mismatching size between number of rows from input data and size of gradient vector."};
if (!model_.learner_model_param->IsVectorLeaf() && p_fmat->Info().num_row_ != 0) {
CHECK_EQ(n_out % gpair->Size(), 0) << msg;
} else {
CHECK_EQ(gpair->Size(), n_out) << msg;
}
CHECK(out_position);
out_position->resize(new_trees.size());
// Rescale learning rate according to the size of trees
@ -402,9 +421,13 @@ void GBTree::BoostNewTrees(HostDeviceVector<GradientPair>* gpair, DMatrix* p_fma
void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees) {
monitor_.Start("CommitModel");
for (uint32_t gid = 0; gid < model_.learner_model_param->num_output_group; ++gid) {
if (this->model_.learner_model_param->IsVectorLeaf()) {
model_.CommitModel(std::move(new_trees[0]), 0);
} else {
for (std::uint32_t gid = 0; gid < model_.learner_model_param->OutputLength(); ++gid) {
model_.CommitModel(std::move(new_trees[gid]), gid);
}
}
monitor_.Stop("CommitModel");
}
@ -564,11 +587,10 @@ void GBTree::PredictBatch(DMatrix* p_fmat,
if (out_preds->version == 0) {
// out_preds->Size() can be non-zero as it's initialized here before any
// tree is built at the 0^th iterator.
predictor->InitOutPredictions(p_fmat->Info(), &out_preds->predictions,
model_);
predictor->InitOutPredictions(p_fmat->Info(), &out_preds->predictions, model_);
}
uint32_t tree_begin, tree_end;
std::uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end);
CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
if (tree_end > tree_begin) {
@ -577,7 +599,7 @@ void GBTree::PredictBatch(DMatrix* p_fmat,
if (reset) {
out_preds->version = 0;
} else {
uint32_t delta = layer_end - out_preds->version;
std::uint32_t delta = layer_end - out_preds->version;
out_preds->Update(delta);
}
}
@ -770,6 +792,7 @@ class Dart : public GBTree {
void PredictBatchImpl(DMatrix *p_fmat, PredictionCacheEntry *p_out_preds,
bool training, unsigned layer_begin,
unsigned layer_end) const {
CHECK(!this->model_.learner_model_param->IsVectorLeaf()) << "dart" << MTNotImplemented();
auto &predictor = this->GetPredictor(&p_out_preds->predictions, p_fmat);
CHECK(predictor);
predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
@ -830,6 +853,7 @@ class Dart : public GBTree {
void InplacePredict(std::shared_ptr<DMatrix> p_fmat, float missing,
PredictionCacheEntry* p_out_preds, uint32_t layer_begin,
unsigned layer_end) const override {
CHECK(!this->model_.learner_model_param->IsVectorLeaf()) << "dart" << MTNotImplemented();
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, layer_begin, layer_end);
auto n_groups = model_.learner_model_param->num_output_group;

View File

@ -140,13 +140,21 @@ struct DartTrainParam : public XGBoostParameter<DartTrainParam> {
namespace detail {
// From here on, layer becomes concrete trees.
inline std::pair<uint32_t, uint32_t> LayerToTree(gbm::GBTreeModel const& model,
size_t layer_begin,
size_t layer_end) {
bst_group_t groups = model.learner_model_param->num_output_group;
uint32_t tree_begin = layer_begin * groups * model.param.num_parallel_tree;
uint32_t tree_end = layer_end * groups * model.param.num_parallel_tree;
std::uint32_t layer_begin,
std::uint32_t layer_end) {
std::uint32_t tree_begin;
std::uint32_t tree_end;
if (model.learner_model_param->IsVectorLeaf()) {
tree_begin = layer_begin * model.param.num_parallel_tree;
tree_end = layer_end * model.param.num_parallel_tree;
} else {
bst_group_t groups = model.learner_model_param->OutputLength();
tree_begin = layer_begin * groups * model.param.num_parallel_tree;
tree_end = layer_end * groups * model.param.num_parallel_tree;
}
if (tree_end == 0) {
tree_end = static_cast<uint32_t>(model.trees.size());
tree_end = model.trees.size();
}
if (model.trees.size() != 0) {
CHECK_LE(tree_begin, tree_end);
@ -234,22 +242,25 @@ class GBTree : public GradientBooster {
void LoadModel(Json const& in) override;
// Number of trees per layer.
auto LayerTrees() const {
auto n_trees = model_.learner_model_param->num_output_group * model_.param.num_parallel_tree;
return n_trees;
[[nodiscard]] std::uint32_t LayerTrees() const {
if (model_.learner_model_param->IsVectorLeaf()) {
return model_.param.num_parallel_tree;
}
return model_.param.num_parallel_tree * model_.learner_model_param->OutputLength();
}
// slice the trees, out must be already allocated
void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
GradientBooster *out, bool* out_of_bound) const override;
int32_t BoostedRounds() const override {
[[nodiscard]] std::int32_t BoostedRounds() const override {
CHECK_NE(model_.param.num_parallel_tree, 0);
CHECK_NE(model_.learner_model_param->num_output_group, 0);
return model_.trees.size() / this->LayerTrees();
}
bool ModelFitted() const override {
[[nodiscard]] bool ModelFitted() const override {
return !model_.trees.empty() || !model_.trees_to_update.empty();
}

View File

@ -326,7 +326,7 @@ struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
std::string booster;
std::string objective;
// This is a training parameter and is not saved (nor loaded) in the model.
MultiStrategy multi_strategy{MultiStrategy::kComposite};
MultiStrategy multi_strategy{MultiStrategy::kOneOutputPerTree};
// declare parameters
DMLC_DECLARE_PARAMETER(LearnerTrainParam) {
@ -339,12 +339,12 @@ struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
.set_default("reg:squarederror")
.describe("Objective function used for obtaining gradient.");
DMLC_DECLARE_FIELD(multi_strategy)
.add_enum("composite", MultiStrategy::kComposite)
.add_enum("monolithic", MultiStrategy::kMonolithic)
.set_default(MultiStrategy::kComposite)
.add_enum("one_output_per_tree", MultiStrategy::kOneOutputPerTree)
.add_enum("multi_output_tree", MultiStrategy::kMultiOutputTree)
.set_default(MultiStrategy::kOneOutputPerTree)
.describe(
"Strategy used for training multi-target models. `monolithic` means building one "
"single tree for all targets.");
"Strategy used for training multi-target models. `multi_output_tree` means building "
"one single tree for all targets.");
}
};

View File

@ -145,7 +145,6 @@ PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info,
auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), predt.Size());
auto d_group_ptr = p_cache->DataGroupPtr(ctx);
auto n_groups = info.group_ptr_.size() - 1;
auto d_inv_idcg = p_cache->InvIDCG(ctx);
auto d_sorted_idx = p_cache->SortedIdx(ctx, d_predt.Values());
@ -171,7 +170,6 @@ PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info,
HostDeviceVector<float> const &predt, bool minus,
std::shared_ptr<ltr::MAPCache> p_cache) {
auto d_group_ptr = p_cache->DataGroupPtr(ctx);
auto n_groups = info.group_ptr_.size() - 1;
auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
predt.SetDevice(ctx->gpu_id);

View File

@ -87,30 +87,6 @@ bst_float PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree,
: GetLeafIndex<false, has_categorical>(tree, p_feats, cats);
return tree[leaf].LeafValue();
}
void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
const size_t tree_end, const size_t predict_offset,
const std::vector<RegTree::FVec> &thread_temp, const size_t offset,
const size_t block_size, linalg::TensorView<float, 2> out_predt) {
for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
const size_t gid = model.tree_info[tree_id];
auto const &tree = *model.trees[tree_id];
auto const &cats = tree.GetCategoriesMatrix();
auto has_categorical = tree.HasCategoricalSplit();
if (has_categorical) {
for (std::size_t i = 0; i < block_size; ++i) {
out_predt(predict_offset + i, gid) +=
PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
}
} else {
for (std::size_t i = 0; i < block_size; ++i) {
out_predt(predict_offset + i, gid) +=
PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
}
}
}
}
} // namespace scalar
namespace multi {
@ -128,7 +104,7 @@ bst_node_t GetLeafIndex(MultiTargetTree const &tree, const RegTree::FVec &feat,
}
template <bool has_categorical>
void PredValueByOneTree(const RegTree::FVec &p_feats, MultiTargetTree const &tree,
void PredValueByOneTree(RegTree::FVec const &p_feats, MultiTargetTree const &tree,
RegTree::CategoricalSplitMatrix const &cats,
linalg::VectorView<float> out_predt) {
bst_node_t const leaf = p_feats.HasMissing()
@ -140,32 +116,48 @@ void PredValueByOneTree(const RegTree::FVec &p_feats, MultiTargetTree const &tre
out_predt(i) += leaf_value(i);
}
}
} // namespace multi
void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin,
const size_t tree_end, const size_t predict_offset,
const std::vector<RegTree::FVec> &thread_temp, const size_t offset,
const size_t block_size, linalg::TensorView<float, 2> out_predt) {
for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
namespace {
void PredictByAllTrees(gbm::GBTreeModel const &model, std::uint32_t const tree_begin,
std::uint32_t const tree_end, std::size_t const predict_offset,
std::vector<RegTree::FVec> const &thread_temp, std::size_t const offset,
std::size_t const block_size, linalg::MatrixView<float> out_predt) {
for (std::uint32_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) {
auto const &tree = *model.trees.at(tree_id);
auto cats = tree.GetCategoriesMatrix();
auto const &cats = tree.GetCategoriesMatrix();
bool has_categorical = tree.HasCategoricalSplit();
if (tree.IsMultiTarget()) {
if (has_categorical) {
for (std::size_t i = 0; i < block_size; ++i) {
auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
PredValueByOneTree<true>(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats,
multi::PredValueByOneTree<true>(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats,
t_predts);
}
} else {
for (std::size_t i = 0; i < block_size; ++i) {
auto t_predts = out_predt.Slice(predict_offset + i, linalg::All());
PredValueByOneTree<false>(thread_temp[offset + i], *tree.GetMultiTargetTree(), cats,
t_predts);
multi::PredValueByOneTree<false>(thread_temp[offset + i], *tree.GetMultiTargetTree(),
cats, t_predts);
}
}
} else {
auto const gid = model.tree_info[tree_id];
if (has_categorical) {
for (std::size_t i = 0; i < block_size; ++i) {
out_predt(predict_offset + i, gid) +=
scalar::PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
}
} else {
for (std::size_t i = 0; i < block_size; ++i) {
out_predt(predict_offset + i, gid) +=
scalar::PredValueByOneTree<true>(thread_temp[offset + i], tree, cats);
}
}
}
}
}
} // namespace multi
template <typename DataView>
void FVecFill(const size_t block_size, const size_t batch_offset, const int num_feature,
@ -190,9 +182,7 @@ void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batc
}
}
namespace {
static std::size_t constexpr kUnroll = 8;
} // anonymous namespace
struct SparsePageView {
bst_row_t base_rowid;
@ -292,7 +282,7 @@ class AdapterView {
template <typename DataView, size_t block_of_rows_size>
void PredictBatchByBlockOfRowsKernel(DataView batch, gbm::GBTreeModel const &model,
int32_t tree_begin, int32_t tree_end,
std::uint32_t tree_begin, std::uint32_t tree_end,
std::vector<RegTree::FVec> *p_thread_temp, int32_t n_threads,
linalg::TensorView<float, 2> out_predt) {
auto &thread_temp = *p_thread_temp;
@ -310,14 +300,8 @@ void PredictBatchByBlockOfRowsKernel(DataView batch, gbm::GBTreeModel const &mod
FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, p_thread_temp);
// process block of rows through all trees to keep cache locality
if (model.learner_model_param->IsVectorLeaf()) {
multi::PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid,
thread_temp, fvec_offset, block_size, out_predt);
} else {
scalar::PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid,
thread_temp, fvec_offset, block_size, out_predt);
}
PredictByAllTrees(model, tree_begin, tree_end, batch_offset + batch.base_rowid, thread_temp,
fvec_offset, block_size, out_predt);
FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp);
});
}
@ -348,7 +332,6 @@ void FillNodeMeanValues(RegTree const* tree, std::vector<float>* mean_values) {
FillNodeMeanValues(tree, 0, mean_values);
}
namespace {
// init thread buffers
static void InitThreadTemp(int nthread, std::vector<RegTree::FVec> *out) {
int prev_thread_temp_size = out->size();

View File

@ -411,7 +411,7 @@ class DeviceModel {
this->tree_beg_ = tree_begin;
this->tree_end_ = tree_end;
this->num_group = model.learner_model_param->num_output_group;
this->num_group = model.learner_model_param->OutputLength();
}
};

View File

@ -306,9 +306,9 @@ class HistogramBuilder {
// Construct a work space for building histogram. Eventually we should move this
// function into histogram builder once hist tree method supports external memory.
template <typename Partitioner>
template <typename Partitioner, typename ExpandEntry = CPUExpandEntry>
common::BlockedSpace2d ConstructHistSpace(Partitioner const &partitioners,
std::vector<CPUExpandEntry> const &nodes_to_build) {
std::vector<ExpandEntry> const &nodes_to_build) {
std::vector<size_t> partition_size(nodes_to_build.size(), 0);
for (auto const &partition : partitioners) {
size_t k = 0;

View File

@ -889,6 +889,8 @@ void RegTree::Save(dmlc::Stream* fo) const {
CHECK_EQ(param_.num_nodes, static_cast<int>(stats_.size()));
CHECK_EQ(param_.deprecated_num_roots, 1);
CHECK_NE(param_.num_nodes, 0);
CHECK(!IsMultiTarget())
<< "Please use JSON/UBJSON for saving models with multi-target trees.";
CHECK(!HasCategoricalSplit())
<< "Please use JSON/UBJSON for saving models with categorical splits.";

View File

@ -4,36 +4,39 @@
* \brief use quantized feature values to construct a tree
* \author Philip Cho, Tianqi Checn, Egor Smirnov
*/
#include <algorithm> // for max
#include <algorithm> // for max, copy, transform
#include <cstddef> // for size_t
#include <cstdint> // for uint32_t
#include <memory> // for unique_ptr, allocator, make_unique, make_shared
#include <ostream> // for operator<<, char_traits, basic_ostream
#include <tuple> // for apply
#include <cstdint> // for uint32_t, int32_t
#include <memory> // for unique_ptr, allocator, make_unique, shared_ptr
#include <numeric> // for accumulate
#include <ostream> // for basic_ostream, char_traits, operator<<
#include <utility> // for move, swap
#include <vector> // for vector
#include "../collective/communicator-inl.h" // for Allreduce, IsDistributed
#include "../collective/communicator.h" // for Operation
#include "../common/hist_util.h" // for HistogramCuts, HistCollection
#include "../common/linalg_op.h" // for begin, cbegin, cend
#include "../common/random.h" // for ColumnSampler
#include "../common/threading_utils.h" // for ParallelFor
#include "../common/timer.h" // for Monitor
#include "../common/transform_iterator.h" // for IndexTransformIter, MakeIndexTransformIter
#include "../data/gradient_index.h" // for GHistIndexMatrix
#include "common_row_partitioner.h" // for CommonRowPartitioner
#include "dmlc/omp.h" // for omp_get_thread_num
#include "dmlc/registry.h" // for DMLC_REGISTRY_FILE_TAG
#include "driver.h" // for Driver
#include "hist/evaluate_splits.h" // for HistEvaluator, UpdatePredictionCacheImpl
#include "hist/expand_entry.h" // for CPUExpandEntry
#include "hist/evaluate_splits.h" // for HistEvaluator, HistMultiEvaluator, UpdatePre...
#include "hist/expand_entry.h" // for MultiExpandEntry, CPUExpandEntry
#include "hist/histogram.h" // for HistogramBuilder, ConstructHistSpace
#include "hist/sampler.h" // for SampleGradient
#include "param.h" // for TrainParam, GradStats
#include "xgboost/base.h" // for GradientPair, GradientPairInternal, bst_node_t
#include "param.h" // for TrainParam, SplitEntryContainer, GradStats
#include "xgboost/base.h" // for GradientPairInternal, GradientPair, bst_targ...
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for BatchIterator, BatchSet, DMatrix, MetaInfo
#include "xgboost/host_device_vector.h" // for HostDeviceVector
#include "xgboost/linalg.h" // for TensorView, MatrixView, UnravelIndex, All
#include "xgboost/logging.h" // for LogCheck_EQ, LogCheck_GE, CHECK_EQ, LOG, LOG...
#include "xgboost/linalg.h" // for All, MatrixView, TensorView, Matrix, Empty
#include "xgboost/logging.h" // for LogCheck_EQ, CHECK_EQ, CHECK, LogCheck_GE
#include "xgboost/span.h" // for Span, operator!=, SpanIterator
#include "xgboost/string_view.h" // for operator<<
#include "xgboost/task.h" // for ObjInfo
@ -105,6 +108,212 @@ void UpdateTree(common::Monitor *monitor_, linalg::MatrixView<GradientPair const
monitor_->Stop(__func__);
}
/**
* \brief Updater for building multi-target trees. The implementation simply iterates over
* each target.
*/
class MultiTargetHistBuilder {
private:
common::Monitor *monitor_{nullptr};
TrainParam const *param_{nullptr};
std::shared_ptr<common::ColumnSampler> col_sampler_;
std::unique_ptr<HistMultiEvaluator> evaluator_;
// Histogram builder for each target.
std::vector<HistogramBuilder<MultiExpandEntry>> histogram_builder_;
Context const *ctx_{nullptr};
// Partitioner for each data batch.
std::vector<CommonRowPartitioner> partitioner_;
// Pointer to last updated tree, used for update prediction cache.
RegTree const *p_last_tree_{nullptr};
ObjInfo const *task_{nullptr};
public:
void UpdatePosition(DMatrix *p_fmat, RegTree const *p_tree,
std::vector<MultiExpandEntry> const &applied) {
monitor_->Start(__func__);
std::size_t page_id{0};
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(this->param_))) {
this->partitioner_.at(page_id).UpdatePosition(this->ctx_, page, applied, p_tree);
page_id++;
}
monitor_->Stop(__func__);
}
void ApplyTreeSplit(MultiExpandEntry const &candidate, RegTree *p_tree) {
this->evaluator_->ApplyTreeSplit(candidate, p_tree);
}
void InitData(DMatrix *p_fmat, RegTree const *p_tree) {
monitor_->Start(__func__);
std::size_t page_id = 0;
bst_bin_t n_total_bins = 0;
partitioner_.clear();
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
if (n_total_bins == 0) {
n_total_bins = page.cut.TotalBins();
} else {
CHECK_EQ(n_total_bins, page.cut.TotalBins());
}
partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->IsColumnSplit());
page_id++;
}
bst_target_t n_targets = p_tree->NumTargets();
histogram_builder_.clear();
for (std::size_t i = 0; i < n_targets; ++i) {
histogram_builder_.emplace_back();
histogram_builder_.back().Reset(n_total_bins, HistBatch(param_), ctx_->Threads(), page_id,
collective::IsDistributed(), p_fmat->IsColumnSplit());
}
evaluator_ = std::make_unique<HistMultiEvaluator>(ctx_, p_fmat->Info(), param_, col_sampler_);
p_last_tree_ = p_tree;
monitor_->Stop(__func__);
}
MultiExpandEntry InitRoot(DMatrix *p_fmat, linalg::MatrixView<GradientPair const> gpair,
RegTree *p_tree) {
monitor_->Start(__func__);
MultiExpandEntry best;
best.nid = RegTree::kRoot;
best.depth = 0;
auto n_targets = p_tree->NumTargets();
linalg::Matrix<GradientPairPrecise> root_sum_tloc =
linalg::Empty<GradientPairPrecise>(ctx_, ctx_->Threads(), n_targets);
CHECK_EQ(root_sum_tloc.Shape(1), gpair.Shape(1));
auto h_root_sum_tloc = root_sum_tloc.HostView();
common::ParallelFor(gpair.Shape(0), ctx_->Threads(), [&](auto i) {
for (bst_target_t t{0}; t < n_targets; ++t) {
h_root_sum_tloc(omp_get_thread_num(), t) += GradientPairPrecise{gpair(i, t)};
}
});
// Aggregate to the first row.
auto root_sum = h_root_sum_tloc.Slice(0, linalg::All());
for (std::int32_t tidx{1}; tidx < ctx_->Threads(); ++tidx) {
for (bst_target_t t{0}; t < n_targets; ++t) {
root_sum(t) += h_root_sum_tloc(tidx, t);
}
}
CHECK(root_sum.CContiguous());
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double *>(root_sum.Values().data()), root_sum.Size() * 2);
std::vector<MultiExpandEntry> nodes{best};
std::size_t i = 0;
auto space = ConstructHistSpace(partitioner_, nodes);
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
for (bst_target_t t{0}; t < n_targets; ++t) {
auto t_gpair = gpair.Slice(linalg::All(), t);
histogram_builder_[t].BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
nodes, {}, t_gpair.Values());
}
i++;
}
auto weight = evaluator_->InitRoot(root_sum);
auto weight_t = weight.HostView();
std::transform(linalg::cbegin(weight_t), linalg::cend(weight_t), linalg::begin(weight_t),
[&](float w) { return w * param_->learning_rate; });
p_tree->SetLeaf(RegTree::kRoot, weight_t);
std::vector<common::HistCollection const *> hists;
for (bst_target_t t{0}; t < p_tree->NumTargets(); ++t) {
hists.push_back(&histogram_builder_[t].Histogram());
}
for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
evaluator_->EvaluateSplits(*p_tree, hists, gmat.cut, &nodes);
break;
}
monitor_->Stop(__func__);
return nodes.front();
}
void BuildHistogram(DMatrix *p_fmat, RegTree const *p_tree,
std::vector<MultiExpandEntry> const &valid_candidates,
linalg::MatrixView<GradientPair const> gpair) {
monitor_->Start(__func__);
std::vector<MultiExpandEntry> nodes_to_build;
std::vector<MultiExpandEntry> nodes_to_sub;
for (auto const &c : valid_candidates) {
auto left_nidx = p_tree->LeftChild(c.nid);
auto right_nidx = p_tree->RightChild(c.nid);
auto build_nidx = left_nidx;
auto subtract_nidx = right_nidx;
auto lit =
common::MakeIndexTransformIter([&](auto i) { return c.split.left_sum[i].GetHess(); });
auto left_sum = std::accumulate(lit, lit + c.split.left_sum.size(), .0);
auto rit =
common::MakeIndexTransformIter([&](auto i) { return c.split.right_sum[i].GetHess(); });
auto right_sum = std::accumulate(rit, rit + c.split.right_sum.size(), .0);
auto fewer_right = right_sum < left_sum;
if (fewer_right) {
std::swap(build_nidx, subtract_nidx);
}
nodes_to_build.emplace_back(build_nidx, p_tree->GetDepth(build_nidx));
nodes_to_sub.emplace_back(subtract_nidx, p_tree->GetDepth(subtract_nidx));
}
std::size_t i = 0;
auto space = ConstructHistSpace(partitioner_, nodes_to_build);
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
for (std::size_t t = 0; t < p_tree->NumTargets(); ++t) {
auto t_gpair = gpair.Slice(linalg::All(), t);
// Make sure the gradient matrix is f-order.
CHECK(t_gpair.Contiguous());
histogram_builder_[t].BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(),
nodes_to_build, nodes_to_sub, t_gpair.Values());
}
i++;
}
monitor_->Stop(__func__);
}
void EvaluateSplits(DMatrix *p_fmat, RegTree const *p_tree,
std::vector<MultiExpandEntry> *best_splits) {
monitor_->Start(__func__);
std::vector<common::HistCollection const *> hists;
for (bst_target_t t{0}; t < p_tree->NumTargets(); ++t) {
hists.push_back(&histogram_builder_[t].Histogram());
}
for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
evaluator_->EvaluateSplits(*p_tree, hists, gmat.cut, best_splits);
break;
}
monitor_->Stop(__func__);
}
void LeafPartition(RegTree const &tree, linalg::MatrixView<GradientPair const> gpair,
std::vector<bst_node_t> *p_out_position) {
monitor_->Start(__func__);
if (!task_->UpdateTreeLeaf()) {
return;
}
for (auto const &part : partitioner_) {
part.LeafPartition(ctx_, tree, gpair, p_out_position);
}
monitor_->Stop(__func__);
}
public:
explicit MultiTargetHistBuilder(Context const *ctx, MetaInfo const &info, TrainParam const *param,
std::shared_ptr<common::ColumnSampler> column_sampler,
ObjInfo const *task, common::Monitor *monitor)
: monitor_{monitor},
param_{param},
col_sampler_{std::move(column_sampler)},
evaluator_{std::make_unique<HistMultiEvaluator>(ctx, info, param, col_sampler_)},
ctx_{ctx},
task_{task} {
monitor_->Init(__func__);
}
};
class HistBuilder {
private:
common::Monitor *monitor_;
@ -155,8 +364,7 @@ class HistBuilder {
// initialize temp data structure
void InitData(DMatrix *fmat, RegTree const *p_tree) {
monitor_->Start(__func__);
size_t page_id{0};
std::size_t page_id{0};
bst_bin_t n_total_bins{0};
partitioner_.clear();
for (auto const &page : fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
@ -195,7 +403,7 @@ class HistBuilder {
RegTree *p_tree) {
CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0));
size_t page_id = 0;
std::size_t page_id = 0;
auto space = ConstructHistSpace(partitioner_, {node});
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
std::vector<CPUExpandEntry> nodes_to_build{node};
@ -214,13 +422,13 @@ class HistBuilder {
* of gradient histogram is equal to snode[nid]
*/
auto const &gmat = *(p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_)).begin());
std::vector<uint32_t> const &row_ptr = gmat.cut.Ptrs();
std::vector<std::uint32_t> const &row_ptr = gmat.cut.Ptrs();
CHECK_GE(row_ptr.size(), 2);
uint32_t const ibegin = row_ptr[0];
uint32_t const iend = row_ptr[1];
std::uint32_t const ibegin = row_ptr[0];
std::uint32_t const iend = row_ptr[1];
auto hist = this->histogram_builder_->Histogram()[RegTree::kRoot];
auto begin = hist.data();
for (uint32_t i = ibegin; i < iend; ++i) {
for (std::uint32_t i = ibegin; i < iend; ++i) {
GradientPairPrecise const &et = begin[i];
grad_stat.Add(et.GetGrad(), et.GetHess());
}
@ -259,7 +467,7 @@ class HistBuilder {
std::vector<CPUExpandEntry> nodes_to_build(valid_candidates.size());
std::vector<CPUExpandEntry> nodes_to_sub(valid_candidates.size());
size_t n_idx = 0;
std::size_t n_idx = 0;
for (auto const &c : valid_candidates) {
auto left_nidx = (*p_tree)[c.nid].LeftChild();
auto right_nidx = (*p_tree)[c.nid].RightChild();
@ -275,7 +483,7 @@ class HistBuilder {
n_idx++;
}
size_t page_id{0};
std::size_t page_id{0};
auto space = ConstructHistSpace(partitioner_, nodes_to_build);
for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
histogram_builder_->BuildHist(page_id, space, gidx, p_tree,
@ -311,11 +519,12 @@ class HistBuilder {
/*! \brief construct a tree using quantized feature values */
class QuantileHistMaker : public TreeUpdater {
std::unique_ptr<HistBuilder> p_impl_;
std::unique_ptr<HistBuilder> p_impl_{nullptr};
std::unique_ptr<MultiTargetHistBuilder> p_mtimpl_{nullptr};
std::shared_ptr<common::ColumnSampler> column_sampler_ =
std::make_shared<common::ColumnSampler>();
common::Monitor monitor_;
ObjInfo const *task_;
ObjInfo const *task_{nullptr};
public:
explicit QuantileHistMaker(Context const *ctx, ObjInfo const *task)
@ -332,7 +541,10 @@ class QuantileHistMaker : public TreeUpdater {
const std::vector<RegTree *> &trees) override {
if (trees.front()->IsMultiTarget()) {
CHECK(param->monotone_constraints.empty()) << "monotone constraint" << MTNotImplemented();
LOG(FATAL) << "Not implemented.";
if (!p_mtimpl_) {
this->p_mtimpl_ = std::make_unique<MultiTargetHistBuilder>(
ctx_, p_fmat->Info(), param, column_sampler_, task_, &monitor_);
}
} else {
if (!p_impl_) {
p_impl_ =
@ -355,13 +567,14 @@ class QuantileHistMaker : public TreeUpdater {
for (auto tree_it = trees.begin(); tree_it != trees.end(); ++tree_it) {
if (need_copy()) {
// Copy gradient into buffer for sampling.
// Copy gradient into buffer for sampling. This converts C-order to F-order.
std::copy(linalg::cbegin(h_gpair), linalg::cend(h_gpair), linalg::begin(h_sample_out));
}
SampleGradient(ctx_, *param, h_sample_out);
auto *h_out_position = &out_position[tree_it - trees.begin()];
if ((*tree_it)->IsMultiTarget()) {
LOG(FATAL) << "Not implemented.";
UpdateTree<MultiExpandEntry>(&monitor_, h_sample_out, p_mtimpl_.get(), p_fmat, param,
h_out_position, *tree_it);
} else {
UpdateTree<CPUExpandEntry>(&monitor_, h_sample_out, p_impl_.get(), p_fmat, param,
h_out_position, *tree_it);
@ -372,6 +585,9 @@ class QuantileHistMaker : public TreeUpdater {
bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView<float> out_preds) override {
if (p_impl_) {
return p_impl_->UpdatePredictionCache(data, out_preds);
} else if (p_mtimpl_) {
// Not yet supported.
return false;
} else {
return false;
}
@ -383,6 +599,6 @@ class QuantileHistMaker : public TreeUpdater {
XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker")
.describe("Grow tree using quantized histogram.")
.set_body([](Context const *ctx, ObjInfo const *task) {
return new QuantileHistMaker(ctx, task);
return new QuantileHistMaker{ctx, task};
});
} // namespace xgboost::tree

View File

@ -3,7 +3,7 @@ import os
import subprocess
import sys
from multiprocessing import Pool, cpu_count
from typing import Dict, Optional, Tuple
from typing import Dict, Tuple
from pylint import epylint
from test_utils import PY_PACKAGE, ROOT, cd, print_time, record_time
@ -15,7 +15,10 @@ SRCPATH = os.path.normpath(
@record_time
def run_black(rel_path: str) -> bool:
def run_black(rel_path: str, fix: bool) -> bool:
if fix:
cmd = ["black", "-q", rel_path]
else:
cmd = ["black", "-q", "--check", rel_path]
ret = subprocess.run(cmd).returncode
if ret != 0:
@ -31,7 +34,10 @@ Please run the following command on your machine to address the formatting error
@record_time
def run_isort(rel_path: str) -> bool:
def run_isort(rel_path: str, fix: bool) -> bool:
if fix:
cmd = ["isort", f"--src={SRCPATH}", "--profile=black", rel_path]
else:
cmd = ["isort", f"--src={SRCPATH}", "--check", "--profile=black", rel_path]
ret = subprocess.run(cmd).returncode
if ret != 0:
@ -132,7 +138,7 @@ def run_pylint() -> bool:
def main(args: argparse.Namespace) -> None:
if args.format == 1:
black_results = [
run_black(path)
run_black(path, args.fix)
for path in [
# core
"python-package/",
@ -166,7 +172,7 @@ def main(args: argparse.Namespace) -> None:
sys.exit(-1)
isort_results = [
run_isort(path)
run_isort(path, args.fix)
for path in [
# core
"python-package/",
@ -230,6 +236,11 @@ if __name__ == "__main__":
parser.add_argument("--format", type=int, choices=[0, 1], default=1)
parser.add_argument("--type-check", type=int, choices=[0, 1], default=1)
parser.add_argument("--pylint", type=int, choices=[0, 1], default=1)
parser.add_argument(
"--fix",
action="store_true",
help="Fix the formatting issues instead of emitting an error.",
)
args = parser.parse_args()
try:
main(args)

View File

@ -412,7 +412,7 @@ std::pair<Json, Json> TestModelSlice(std::string booster) {
j++;
}
// CHECK sliced model doesn't have dependency on old one
// CHECK sliced model doesn't have dependency on the old one
learner.reset();
CHECK_EQ(sliced->GetNumFeature(), kCols);

View File

@ -473,7 +473,7 @@ inline LearnerModelParam MakeMP(bst_feature_t n_features, float base_score, uint
int32_t device = Context::kCpuId) {
size_t shape[1]{1};
LearnerModelParam mparam(n_features, linalg::Tensor<float, 1>{{base_score}, shape, device},
n_groups, 1, MultiStrategy::kComposite);
n_groups, 1, MultiStrategy::kOneOutputPerTree);
return mparam;
}

View File

@ -428,7 +428,7 @@ void TestVectorLeafPrediction(Context const *ctx) {
LearnerModelParam mparam{static_cast<bst_feature_t>(kCols),
linalg::Vector<float>{{0.5}, {1}, Context::kCpuId}, 1, 3,
MultiStrategy::kMonolithic};
MultiStrategy::kMultiOutputTree};
std::vector<std::unique_ptr<RegTree>> trees;
trees.emplace_back(new RegTree{mparam.LeafLength(), mparam.num_feature});

View File

@ -124,11 +124,11 @@ TEST(MultiStrategy, Configure) {
auto p_fmat = RandomDataGenerator{12ul, 3ul, 0.0}.GenerateDMatrix();
p_fmat->Info().labels.Reshape(p_fmat->Info().num_row_, 2);
std::unique_ptr<Learner> learner{Learner::Create({p_fmat})};
learner->SetParams(Args{{"multi_strategy", "monolithic"}, {"num_target", "2"}});
learner->SetParams(Args{{"multi_strategy", "multi_output_tree"}, {"num_target", "2"}});
learner->Configure();
ASSERT_EQ(learner->Groups(), 2);
learner->SetParams(Args{{"multi_strategy", "monolithic"}, {"num_target", "0"}});
learner->SetParams(Args{{"multi_strategy", "multi_output_tree"}, {"num_target", "0"}});
ASSERT_THROW({ learner->Configure(); }, dmlc::Error);
}
} // namespace xgboost

View File

@ -116,7 +116,7 @@ def test_with_mq2008(objective, metric) -> None:
x_valid,
y_valid,
qid_valid,
) = tm.get_mq2008(os.path.join(os.path.join(tm.demo_dir(__file__), "rank")))
) = tm.data.get_mq2008(os.path.join(os.path.join(tm.demo_dir(__file__), "rank")))
if metric.find("map") != -1 or objective.find("map") != -1:
y_train[y_train <= 1] = 0.0

View File

@ -32,6 +32,19 @@ def train_result(param, dmat: xgb.DMatrix, num_rounds: int) -> dict:
return result
class TestGPUUpdatersMulti:
@given(
hist_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
)
@settings(deadline=None, max_examples=50, print_blob=True)
def test_hist(self, param, num_rounds, dataset):
param["tree_method"] = "gpu_hist"
param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), num_rounds)
note(result)
assert tm.non_increasing(result["train"][dataset.metric])
class TestGPUUpdaters:
cputest = test_up.TestTreeMethod()
@ -101,7 +114,7 @@ class TestGPUUpdaters:
) -> None:
cat_parameters.update(hist_parameters)
dataset = tm.TestDataset(
"ames_housing", tm.get_ames_housing, "reg:squarederror", "rmse"
"ames_housing", tm.data.get_ames_housing, "reg:squarederror", "rmse"
)
cat_parameters["tree_method"] = "gpu_hist"
results = train_result(cat_parameters, dataset.get_dmat(), 16)

View File

@ -15,13 +15,17 @@ rng = np.random.RandomState(1994)
def json_model(model_path: str, parameters: dict) -> dict:
X = np.random.random((10, 3))
y = np.random.randint(2, size=(10,))
datasets = pytest.importorskip("sklearn.datasets")
X, y = datasets.make_classification(64, n_features=8, n_classes=3, n_informative=6)
if parameters.get("objective", None) == "multi:softmax":
parameters["num_class"] = 3
dm1 = xgb.DMatrix(X, y)
bst = xgb.train(parameters, dm1)
bst.save_model(model_path)
if model_path.endswith("ubj"):
import ubjson
with open(model_path, "rb") as ubjfd:
@ -326,6 +330,8 @@ class TestModels:
from_ubjraw = xgb.Booster()
from_ubjraw.load_model(ubj_raw)
if parameters.get("multi_strategy", None) != "multi_output_tree":
# old binary model is not supported.
old_from_json = from_jraw.save_raw(raw_format="deprecated")
old_from_ubj = from_ubjraw.save_raw(raw_format="deprecated")
@ -335,15 +341,32 @@ class TestModels:
pretty = json.dumps(json.loads(raw_json), indent=2) + "\n\n"
bst.load_model(bytearray(pretty, encoding="ascii"))
if parameters.get("multi_strategy", None) != "multi_output_tree":
# old binary model is not supported.
old_from_json = from_jraw.save_raw(raw_format="deprecated")
old_from_ubj = from_ubjraw.save_raw(raw_format="deprecated")
assert old_from_json == old_from_ubj
rng = np.random.default_rng()
X = rng.random(size=from_jraw.num_features() * 10).reshape(
(10, from_jraw.num_features())
)
predt_from_jraw = from_jraw.predict(xgb.DMatrix(X))
predt_from_bst = bst.predict(xgb.DMatrix(X))
np.testing.assert_allclose(predt_from_jraw, predt_from_bst)
@pytest.mark.parametrize("ext", ["json", "ubj"])
def test_model_json_io(self, ext: str) -> None:
parameters = {"booster": "gbtree", "tree_method": "hist"}
self.run_model_json_io(parameters, ext)
parameters = {
"booster": "gbtree",
"tree_method": "hist",
"multi_strategy": "multi_output_tree",
"objective": "multi:softmax",
}
self.run_model_json_io(parameters, ext)
parameters = {"booster": "gblinear"}
self.run_model_json_io(parameters, ext)
parameters = {"booster": "dart", "tree_method": "hist"}

View File

@ -465,7 +465,7 @@ class TestCallbacks:
assert os.path.exists(os.path.join(tmpdir, "model_" + str(i) + ".pkl"))
def test_callback_list(self):
X, y = tm.get_california_housing()
X, y = tm.data.get_california_housing()
m = xgb.DMatrix(X, y)
callbacks = [xgb.callback.EarlyStopping(rounds=10)]
for i in range(4):

View File

@ -82,7 +82,7 @@ class TestRanking:
"""
cls.dpath = 'demo/rank/'
(x_train, y_train, qid_train, x_test, y_test, qid_test,
x_valid, y_valid, qid_valid) = tm.get_mq2008(cls.dpath)
x_valid, y_valid, qid_valid) = tm.data.get_mq2008(cls.dpath)
# instantiate the matrices
cls.dtrain = xgboost.DMatrix(x_train, y_train)

View File

@ -11,6 +11,7 @@ from xgboost import testing as tm
from xgboost.testing.params import (
cat_parameter_strategy,
exact_parameter_strategy,
hist_multi_parameter_strategy,
hist_parameter_strategy,
)
from xgboost.testing.updater import check_init_estimation, check_quantile_loss
@ -18,11 +19,70 @@ from xgboost.testing.updater import check_init_estimation, check_quantile_loss
def train_result(param, dmat, num_rounds):
result = {}
xgb.train(param, dmat, num_rounds, [(dmat, 'train')], verbose_eval=False,
evals_result=result)
booster = xgb.train(
param,
dmat,
num_rounds,
[(dmat, "train")],
verbose_eval=False,
evals_result=result,
)
assert booster.num_features() == dmat.num_col()
assert booster.num_boosted_rounds() == num_rounds
assert booster.feature_names == dmat.feature_names
assert booster.feature_types == dmat.feature_types
return result
class TestTreeMethodMulti:
@given(
exact_parameter_strategy, strategies.integers(1, 20), tm.multi_dataset_strategy
)
@settings(deadline=None, print_blob=True)
def test_exact(self, param: dict, num_rounds: int, dataset: tm.TestDataset) -> None:
if dataset.name.endswith("-l1"):
return
param["tree_method"] = "exact"
param = dataset.set_params(param)
result = train_result(param, dataset.get_dmat(), num_rounds)
assert tm.non_increasing(result["train"][dataset.metric])
@given(
exact_parameter_strategy,
hist_parameter_strategy,
strategies.integers(1, 20),
tm.multi_dataset_strategy,
)
@settings(deadline=None, print_blob=True)
def test_approx(self, param, hist_param, num_rounds, dataset):
param["tree_method"] = "approx"
param = dataset.set_params(param)
param.update(hist_param)
result = train_result(param, dataset.get_dmat(), num_rounds)
note(result)
assert tm.non_increasing(result["train"][dataset.metric])
@given(
exact_parameter_strategy,
hist_multi_parameter_strategy,
strategies.integers(1, 20),
tm.multi_dataset_strategy,
)
@settings(deadline=None, print_blob=True)
def test_hist(
self, param: dict, hist_param: dict, num_rounds: int, dataset: tm.TestDataset
) -> None:
if dataset.name.endswith("-l1"):
return
param["tree_method"] = "hist"
param = dataset.set_params(param)
param.update(hist_param)
result = train_result(param, dataset.get_dmat(), num_rounds)
note(result)
assert tm.non_increasing(result["train"][dataset.metric])
class TestTreeMethod:
USE_ONEHOT = np.iinfo(np.int32).max
USE_PART = 1
@ -77,10 +137,14 @@ class TestTreeMethod:
# Second prune should not change the tree
assert after_prune == second_prune
@given(exact_parameter_strategy, hist_parameter_strategy, strategies.integers(1, 20),
tm.dataset_strategy)
@given(
exact_parameter_strategy,
hist_parameter_strategy,
strategies.integers(1, 20),
tm.dataset_strategy
)
@settings(deadline=None, print_blob=True)
def test_hist(self, param, hist_param, num_rounds, dataset):
def test_hist(self, param: dict, hist_param: dict, num_rounds: int, dataset: tm.TestDataset) -> None:
param['tree_method'] = 'hist'
param = dataset.set_params(param)
param.update(hist_param)
@ -88,23 +152,6 @@ class TestTreeMethod:
note(result)
assert tm.non_increasing(result['train'][dataset.metric])
@given(tm.sparse_datasets_strategy)
@settings(deadline=None, print_blob=True)
def test_sparse(self, dataset):
param = {"tree_method": "hist", "max_bin": 64}
hist_result = train_result(param, dataset.get_dmat(), 16)
note(hist_result)
assert tm.non_increasing(hist_result['train'][dataset.metric])
param = {"tree_method": "approx", "max_bin": 64}
approx_result = train_result(param, dataset.get_dmat(), 16)
note(approx_result)
assert tm.non_increasing(approx_result['train'][dataset.metric])
np.testing.assert_allclose(
hist_result["train"]["rmse"], approx_result["train"]["rmse"]
)
def test_hist_categorical(self):
# hist must be the same as exact on all-categorical data
dpath = 'demo/data/'
@ -143,6 +190,23 @@ class TestTreeMethod:
w = [0, 0, 1, 0]
model.fit(X, y, sample_weight=w)
@given(tm.sparse_datasets_strategy)
@settings(deadline=None, print_blob=True)
def test_sparse(self, dataset):
param = {"tree_method": "hist", "max_bin": 64}
hist_result = train_result(param, dataset.get_dmat(), 16)
note(hist_result)
assert tm.non_increasing(hist_result['train'][dataset.metric])
param = {"tree_method": "approx", "max_bin": 64}
approx_result = train_result(param, dataset.get_dmat(), 16)
note(approx_result)
assert tm.non_increasing(approx_result['train'][dataset.metric])
np.testing.assert_allclose(
hist_result["train"]["rmse"], approx_result["train"]["rmse"]
)
def run_invalid_category(self, tree_method: str) -> None:
rng = np.random.default_rng()
# too large
@ -365,7 +429,7 @@ class TestTreeMethod:
) -> None:
cat_parameters.update(hist_parameters)
dataset = tm.TestDataset(
"ames_housing", tm.get_ames_housing, "reg:squarederror", "rmse"
"ames_housing", tm.data.get_ames_housing, "reg:squarederror", "rmse"
)
cat_parameters["tree_method"] = tree_method
results = train_result(cat_parameters, dataset.get_dmat(), 16)

View File

@ -1168,7 +1168,7 @@ def test_dask_aft_survival() -> None:
def test_dask_ranking(client: "Client") -> None:
dpath = "demo/rank/"
mq2008 = tm.get_mq2008(dpath)
mq2008 = tm.data.get_mq2008(dpath)
data = []
for d in mq2008:
if isinstance(d, scipy.sparse.csr_matrix):