Init estimation for regression. (#8272)
This commit is contained in:
parent
1b58d81315
commit
badeff1d74
57
.github/workflows/python_tests.yml
vendored
57
.github/workflows/python_tests.yml
vendored
@ -213,3 +213,60 @@ jobs:
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
pytest -s -v -rxXs --durations=0 ./tests/python
|
||||
|
||||
python-tests-on-ubuntu:
|
||||
name: Test XGBoost Python package on ${{ matrix.config.os }}
|
||||
runs-on: ${{ matrix.config.os }}
|
||||
timeout-minutes: 90
|
||||
strategy:
|
||||
matrix:
|
||||
config:
|
||||
- {os: ubuntu-latest, python-version: "3.8"}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
submodules: 'true'
|
||||
|
||||
- uses: mamba-org/provision-with-micromamba@f347426e5745fe3dfc13ec5baf20496990d0281f # v14
|
||||
with:
|
||||
cache-downloads: true
|
||||
cache-env: true
|
||||
environment-name: linux_cpu_test
|
||||
environment-file: tests/ci_build/conda_env/linux_cpu_test.yml
|
||||
|
||||
- name: Display Conda env
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
conda info
|
||||
conda list
|
||||
|
||||
- name: Build XGBoost on Ubuntu
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
mkdir build
|
||||
cd build
|
||||
cmake .. -GNinja -DCMAKE_PREFIX_PATH=$CONDA_PREFIX
|
||||
ninja
|
||||
|
||||
- name: Install Python package
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
cd python-package
|
||||
python --version
|
||||
python setup.py install
|
||||
|
||||
- name: Test Python package
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
pytest -s -v -rxXs --durations=0 ./tests/python
|
||||
|
||||
- name: Test Dask Interface
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
pytest -s -v -rxXs --durations=0 ./tests/test_distributed/test_with_dask
|
||||
|
||||
- name: Test PySpark Interface
|
||||
shell: bash -l {0}
|
||||
run: |
|
||||
pytest -s -v -rxXs --durations=0 ./tests/test_distributed/test_with_spark
|
||||
|
||||
@ -320,7 +320,7 @@ test_that("prediction in early-stopping xgb.cv works", {
|
||||
expect_output(
|
||||
cv <- xgb.cv(param, dtrain, nfold = 5, eta = 0.1, nrounds = 20,
|
||||
early_stopping_rounds = 5, maximize = FALSE, stratified = FALSE,
|
||||
prediction = TRUE)
|
||||
prediction = TRUE, base_score = 0.5)
|
||||
, "Stopping. Best iteration")
|
||||
|
||||
expect_false(is.null(cv$best_iteration))
|
||||
|
||||
@ -27,11 +27,13 @@ if (isTRUE(VCD_AVAILABLE)) {
|
||||
# binary
|
||||
bst.Tree <- xgboost(data = sparse_matrix, label = label, max_depth = 9,
|
||||
eta = 1, nthread = 2, nrounds = nrounds, verbose = 0,
|
||||
objective = "binary:logistic", booster = "gbtree")
|
||||
objective = "binary:logistic", booster = "gbtree",
|
||||
base_score = 0.5)
|
||||
|
||||
bst.GLM <- xgboost(data = sparse_matrix, label = label,
|
||||
eta = 1, nthread = 1, nrounds = nrounds, verbose = 0,
|
||||
objective = "binary:logistic", booster = "gblinear")
|
||||
objective = "binary:logistic", booster = "gblinear",
|
||||
base_score = 0.5)
|
||||
|
||||
feature.names <- colnames(sparse_matrix)
|
||||
}
|
||||
@ -360,7 +362,8 @@ test_that("xgb.importance works with and without feature names", {
|
||||
m <- xgboost::xgboost(
|
||||
data = as.matrix(data.frame(x = c(0, 1))),
|
||||
label = c(1, 2),
|
||||
nrounds = 1
|
||||
nrounds = 1,
|
||||
base_score = 0.5
|
||||
)
|
||||
df <- xgb.model.dt.tree(model = m)
|
||||
expect_equal(df$Feature, "Leaf")
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
'''
|
||||
"""
|
||||
Demo for using feature weight to change column sampling
|
||||
=======================================================
|
||||
|
||||
.. versionadded:: 1.3.0
|
||||
'''
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
@ -13,10 +13,10 @@ from matplotlib import pyplot as plt
|
||||
import xgboost
|
||||
|
||||
|
||||
def main(args):
|
||||
def main(args: argparse.Namespace) -> None:
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
kRows = 1000
|
||||
kRows = 4196
|
||||
kCols = 10
|
||||
|
||||
X = rng.randn(kRows, kCols)
|
||||
@ -28,26 +28,32 @@ def main(args):
|
||||
dtrain = xgboost.DMatrix(X, y)
|
||||
dtrain.set_info(feature_weights=fw)
|
||||
|
||||
bst = xgboost.train({'tree_method': 'hist',
|
||||
'colsample_bynode': 0.2},
|
||||
dtrain, num_boost_round=10,
|
||||
evals=[(dtrain, 'd')])
|
||||
# Perform column sampling for each node split evaluation, the sampling process is
|
||||
# weighted by feature weights.
|
||||
bst = xgboost.train(
|
||||
{"tree_method": "hist", "colsample_bynode": 0.2},
|
||||
dtrain,
|
||||
num_boost_round=10,
|
||||
evals=[(dtrain, "d")],
|
||||
)
|
||||
feature_map = bst.get_fscore()
|
||||
|
||||
# feature zero has 0 weight
|
||||
assert feature_map.get('f0', None) is None
|
||||
assert max(feature_map.values()) == feature_map.get('f9')
|
||||
assert feature_map.get("f0", None) is None
|
||||
assert max(feature_map.values()) == feature_map.get("f9")
|
||||
|
||||
if args.plot:
|
||||
xgboost.plot_importance(bst)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--plot',
|
||||
"--plot",
|
||||
type=int,
|
||||
default=1,
|
||||
help='Set to 0 to disable plotting the evaluation history.')
|
||||
help="Set to 0 to disable plotting the evaluation history.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
@ -12,10 +12,15 @@ import xgboost as xgb
|
||||
if __name__ == "__main__":
|
||||
print("Parallel Parameter optimization")
|
||||
X, y = fetch_california_housing(return_X_y=True)
|
||||
xgb_model = xgb.XGBRegressor(n_jobs=multiprocessing.cpu_count() // 2)
|
||||
clf = GridSearchCV(xgb_model, {'max_depth': [2, 4, 6],
|
||||
'n_estimators': [50, 100, 200]}, verbose=1,
|
||||
n_jobs=2)
|
||||
xgb_model = xgb.XGBRegressor(
|
||||
n_jobs=multiprocessing.cpu_count() // 2, tree_method="hist"
|
||||
)
|
||||
clf = GridSearchCV(
|
||||
xgb_model,
|
||||
{"max_depth": [2, 4, 6], "n_estimators": [50, 100, 200]},
|
||||
verbose=1,
|
||||
n_jobs=2,
|
||||
)
|
||||
clf.fit(X, y)
|
||||
print(clf.best_score_)
|
||||
print(clf.best_params_)
|
||||
|
||||
@ -261,10 +261,10 @@ class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuit
|
||||
"eta" -> "1",
|
||||
"max_depth" -> "6",
|
||||
"silent" -> "1",
|
||||
"base_score" -> 0.5,
|
||||
"objective" -> "binary:logistic",
|
||||
"tree_method" -> treeMethod,
|
||||
"max_bin" -> 16)
|
||||
|
||||
val model1 = ScalaXGBoost.train(trainingDM, paramMap, round)
|
||||
val prediction1 = model1.predict(testDM)
|
||||
|
||||
@ -453,5 +453,4 @@ class XGBoostClassifierSuite extends FunSuite with PerTest with TmpFolderPerSuit
|
||||
assert(!compareTwoFiles(new File(modelJsonPath, "data/XGBoostClassificationModel").getPath,
|
||||
nativeUbjModelPath))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@ -1078,7 +1078,7 @@ class XGBModel(XGBModelBase):
|
||||
validate_features: bool = True,
|
||||
base_margin: Optional[ArrayLike] = None,
|
||||
iteration_range: Optional[Tuple[int, int]] = None,
|
||||
) -> np.ndarray:
|
||||
) -> ArrayLike:
|
||||
"""Predict with `X`. If the model is trained with early stopping, then `best_iteration`
|
||||
is used automatically. For tree models, when data is on GPU, like cupy array or
|
||||
cuDF dataframe and `predictor` is not specified, the prediction is run on GPU
|
||||
@ -1528,7 +1528,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
validate_features: bool = True,
|
||||
base_margin: Optional[ArrayLike] = None,
|
||||
iteration_range: Optional[Tuple[int, int]] = None,
|
||||
) -> np.ndarray:
|
||||
) -> ArrayLike:
|
||||
with config_context(verbosity=self.verbosity):
|
||||
class_probs = super().predict(
|
||||
X=X,
|
||||
|
||||
54
python-package/xgboost/testing/dask.py
Normal file
54
python-package/xgboost/testing/dask.py
Normal file
@ -0,0 +1,54 @@
|
||||
"""Tests for dask shared by different test modules."""
|
||||
import numpy as np
|
||||
from dask import array as da
|
||||
from distributed import Client
|
||||
from xgboost.testing.updater import get_basescore
|
||||
|
||||
import xgboost as xgb
|
||||
|
||||
|
||||
def check_init_estimation_clf(tree_method: str, client: Client) -> None:
    """Test init estimation for classifier.

    Fits a single-node ``XGBClassifier`` and a ``DaskXGBClassifier`` with the
    same hyper-parameters on the same synthetic data, then asserts that both
    models report (numerically) the same ``base_score``.

    Parameters
    ----------
    tree_method :
        Value forwarded as ``tree_method`` to both estimators.
    client :
        Dask client assigned to the distributed estimator before fitting.

    """
    from sklearn.datasets import make_classification

    params = {"n_estimators": 1, "max_depth": 1, "tree_method": tree_method}
    X, y = make_classification(n_samples=4096 * 2, n_features=32, random_state=1994)

    # Reference value obtained from single-node training.
    local_clf = xgb.XGBClassifier(**params)
    local_clf.fit(X, y)
    local_score = get_basescore(local_clf)

    # Same data, re-chunked into blocks of 32 rows for the dask estimator.
    chunked_x = da.from_array(X).rechunk(chunks=(32, None))
    chunked_y = da.from_array(y).rechunk(chunks=(32,))
    dask_clf = xgb.dask.DaskXGBClassifier(**params)
    dask_clf.client = client
    dask_clf.fit(chunked_x, chunked_y)
    dask_score = get_basescore(dask_clf)

    np.testing.assert_allclose(local_score, dask_score)
|
||||
|
||||
|
||||
def check_init_estimation_reg(tree_method: str, client: Client) -> None:
    """Test init estimation for regressor.

    Fits a single-node ``XGBRegressor`` and a ``DaskXGBRegressor`` with the
    same hyper-parameters on the same synthetic data, then asserts that both
    models report (numerically) the same ``base_score``.

    Parameters
    ----------
    tree_method :
        Value forwarded as ``tree_method`` to both estimators.
    client :
        Dask client assigned to the distributed estimator before fitting.

    """
    from sklearn.datasets import make_regression

    params = {"n_estimators": 1, "max_depth": 1, "tree_method": tree_method}
    # pylint: disable=unbalanced-tuple-unpacking
    X, y = make_regression(n_samples=4096 * 2, n_features=32, random_state=1994)

    # Reference value obtained from single-node training.
    local_reg = xgb.XGBRegressor(**params)
    local_reg.fit(X, y)
    local_score = get_basescore(local_reg)

    # Same data, re-chunked into blocks of 32 rows for the dask estimator.
    chunked_x = da.from_array(X).rechunk(chunks=(32, None))
    chunked_y = da.from_array(y).rechunk(chunks=(32,))
    dask_reg = xgb.dask.DaskXGBRegressor(**params)
    dask_reg.client = client
    dask_reg.fit(chunked_x, chunked_y)
    dask_score = get_basescore(dask_reg)

    np.testing.assert_allclose(local_score, dask_score)
|
||||
|
||||
|
||||
def check_init_estimation(tree_method: str, client: Client) -> None:
    """Test init estimation.

    Runs the regressor check first and the classifier check second, passing
    ``tree_method`` and ``client`` through to each unchanged.

    """
    for check in (check_init_estimation_reg, check_init_estimation_clf):
        check(tree_method, client)
|
||||
70
python-package/xgboost/testing/updater.py
Normal file
70
python-package/xgboost/testing/updater.py
Normal file
@ -0,0 +1,70 @@
|
||||
"""Tests for updaters."""
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
|
||||
import xgboost as xgb
|
||||
|
||||
|
||||
def get_basescore(model: xgb.XGBModel) -> float:
    """Get base score from an XGBoost sklearn estimator.

    The value is read from the booster's saved JSON configuration, under
    ``learner -> learner_model_param -> base_score``, and converted to ``float``.

    """
    config = json.loads(model.get_booster().save_config())
    learner_model_param = config["learner"]["learner_model_param"]
    return float(learner_model_param["base_score"])
|
||||
|
||||
|
||||
def check_init_estimation(tree_method: str) -> None:
    """Test for init estimation.

    For each dataset an estimator is trained twice with a single depth-1 tree:
    once with default parameters and once with ``boost_from_average=0``.  The
    two runs must report different base scores, and the default run must reach
    a strictly lower first-round training metric.

    Parameters
    ----------
    tree_method :
        Value forwarded as ``tree_method`` to every estimator.

    """
    from sklearn.datasets import (
        make_classification,
        make_multilabel_classification,
        make_regression,
    )

    def fit_and_score(  # pylint: disable=invalid-name
        estimator_cls: type, metric: str, X: np.ndarray, y: np.ndarray, **extra: int
    ) -> tuple:
        """Fit one estimator on (X, y); return (base score, first eval metric)."""
        estimator = estimator_cls(
            tree_method=tree_method, max_depth=1, n_estimators=1, **extra
        )
        estimator.fit(X, y, eval_set=[(X, y)])
        base_score = get_basescore(estimator)
        first_eval = estimator.evals_result()["validation_0"][metric][0]
        return base_score, first_eval

    def compare(  # pylint: disable=invalid-name
        estimator_cls: type, metric: str, X: np.ndarray, y: np.ndarray
    ) -> None:
        """Compare the default run against the ``boost_from_average=0`` run."""
        base_score_0, score_0 = fit_and_score(estimator_cls, metric, X, y)
        base_score_1, score_1 = fit_and_score(
            estimator_cls, metric, X, y, boost_from_average=0
        )
        assert not np.isclose(base_score_0, base_score_1)
        assert score_0 < score_1  # should be better

    # Regression: single target, then three targets.
    # pylint: disable=unbalanced-tuple-unpacking
    X, y = make_regression(n_samples=4096, random_state=17)
    compare(xgb.XGBRegressor, "rmse", X, y)
    # pylint: disable=unbalanced-tuple-unpacking
    X, y = make_regression(n_samples=4096, n_targets=3, random_state=17)
    compare(xgb.XGBRegressor, "rmse", X, y)

    # Classification: default make_classification data, then multi-label.
    # pylint: disable=unbalanced-tuple-unpacking
    X, y = make_classification(n_samples=4096, random_state=17)
    compare(xgb.XGBClassifier, "logloss", X, y)
    X, y = make_multilabel_classification(
        n_samples=4096, n_labels=3, n_classes=5, random_state=17
    )
    compare(xgb.XGBClassifier, "logloss", X, y)
|
||||
@ -119,7 +119,7 @@ class RabitCommunicator : public Communicator {
|
||||
}
|
||||
|
||||
template <typename DType, std::enable_if_t<std::is_floating_point<DType>::value> * = nullptr>
|
||||
void DoBitwiseAllReduce(void *send_receive_buffer, std::size_t count, Operation op) {
|
||||
void DoBitwiseAllReduce(void *, std::size_t, Operation) {
|
||||
LOG(FATAL) << "Floating point types do not support bitwise operations.";
|
||||
}
|
||||
|
||||
|
||||
@ -684,7 +684,7 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
|
||||
}
|
||||
}
|
||||
|
||||
void MetaInfo::Validate(int32_t device) const {
|
||||
void MetaInfo::Validate(std::int32_t device) const {
|
||||
if (group_ptr_.size() != 0 && weights_.Size() != 0) {
|
||||
CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
|
||||
<< "Size of weights must equal to number of groups when ranking "
|
||||
|
||||
@ -15,7 +15,7 @@
|
||||
#include "../common/hist_util.h"
|
||||
#include "../common/numeric.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/transform_iterator.h" // MakeIndexTransformIter
|
||||
#include "../common/transform_iterator.h" // common::MakeIndexTransformIter
|
||||
#include "adapter.h"
|
||||
#include "proxy_dmatrix.h"
|
||||
#include "xgboost/base.h"
|
||||
|
||||
@ -190,6 +190,32 @@ struct LearnerModelParamLegacy : public dmlc::Parameter<LearnerModelParamLegacy>
|
||||
}
|
||||
return dmlc::Parameter<LearnerModelParamLegacy>::UpdateAllowUnknown(kwargs);
|
||||
}
|
||||
// sanity check
|
||||
void Validate() {
|
||||
if (!collective::IsDistributed()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::array<std::int32_t, 6> data;
|
||||
std::size_t pos{0};
|
||||
std::memcpy(data.data() + pos, &base_score, sizeof(base_score));
|
||||
pos += 1;
|
||||
std::memcpy(data.data() + pos, &num_feature, sizeof(num_feature));
|
||||
pos += 1;
|
||||
std::memcpy(data.data() + pos, &num_class, sizeof(num_class));
|
||||
pos += 1;
|
||||
std::memcpy(data.data() + pos, &num_target, sizeof(num_target));
|
||||
pos += 1;
|
||||
std::memcpy(data.data() + pos, &major_version, sizeof(major_version));
|
||||
pos += 1;
|
||||
std::memcpy(data.data() + pos, &minor_version, sizeof(minor_version));
|
||||
|
||||
std::array<std::int32_t, 6> sync;
|
||||
std::copy(data.cbegin(), data.cend(), sync.begin());
|
||||
collective::Broadcast(sync.data(), sync.size(), 0);
|
||||
CHECK(std::equal(data.cbegin(), data.cend(), sync.cbegin()))
|
||||
<< "Different model parameter across workers.";
|
||||
}
|
||||
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(LearnerModelParamLegacy) {
|
||||
@ -391,6 +417,7 @@ class LearnerConfiguration : public Learner {
|
||||
}
|
||||
// Update the shared model parameter
|
||||
this->ConfigureModelParamWithoutBaseScore();
|
||||
mparam_.Validate();
|
||||
}
|
||||
CHECK(!std::isnan(mparam_.base_score));
|
||||
CHECK(!std::isinf(mparam_.base_score));
|
||||
|
||||
@ -20,6 +20,7 @@
|
||||
#include "../common/stats.h"
|
||||
#include "../common/threading_utils.h"
|
||||
#include "../common/transform.h"
|
||||
#include "../tree/fit_stump.h" // FitStump
|
||||
#include "./regression_loss.h"
|
||||
#include "adaptive.h"
|
||||
#include "xgboost/base.h"
|
||||
@ -53,6 +54,31 @@ void CheckRegInputs(MetaInfo const& info, HostDeviceVector<bst_float> const& pre
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
class RegInitEstimation : public ObjFunction {
|
||||
void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) const override {
|
||||
CheckInitInputs(info);
|
||||
// Avoid altering any state in child objective.
|
||||
HostDeviceVector<float> dummy_predt(info.labels.Size(), 0.0f, this->ctx_->gpu_id);
|
||||
HostDeviceVector<GradientPair> gpair(info.labels.Size(), GradientPair{}, this->ctx_->gpu_id);
|
||||
|
||||
Json config{Object{}};
|
||||
this->SaveConfig(&config);
|
||||
|
||||
std::unique_ptr<ObjFunction> new_obj{
|
||||
ObjFunction::Create(get<String const>(config["name"]), this->ctx_)};
|
||||
new_obj->LoadConfig(config);
|
||||
new_obj->GetGradient(dummy_predt, info, 0, &gpair);
|
||||
bst_target_t n_targets = this->Targets(info);
|
||||
linalg::Vector<float> leaf_weight;
|
||||
tree::FitStump(this->ctx_, gpair, n_targets, &leaf_weight);
|
||||
|
||||
// workaround, we don't support multi-target due to binary model serialization for
|
||||
// base margin.
|
||||
common::Mean(this->ctx_, leaf_weight, base_score);
|
||||
this->PredTransform(base_score->Data());
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
DMLC_REGISTRY_FILE_TAG(regression_obj_gpu);
|
||||
#endif // defined(XGBOOST_USE_CUDA)
|
||||
@ -67,7 +93,7 @@ struct RegLossParam : public XGBoostParameter<RegLossParam> {
|
||||
};
|
||||
|
||||
template<typename Loss>
|
||||
class RegLossObj : public ObjFunction {
|
||||
class RegLossObj : public RegInitEstimation {
|
||||
protected:
|
||||
HostDeviceVector<float> additional_input_;
|
||||
|
||||
@ -214,7 +240,7 @@ XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear")
|
||||
return new RegLossObj<LinearSquareLoss>(); });
|
||||
// End deprecated
|
||||
|
||||
class PseudoHuberRegression : public ObjFunction {
|
||||
class PseudoHuberRegression : public RegInitEstimation {
|
||||
PesudoHuberParam param_;
|
||||
|
||||
public:
|
||||
@ -289,7 +315,7 @@ struct PoissonRegressionParam : public XGBoostParameter<PoissonRegressionParam>
|
||||
};
|
||||
|
||||
// poisson regression for count
|
||||
class PoissonRegression : public ObjFunction {
|
||||
class PoissonRegression : public RegInitEstimation {
|
||||
public:
|
||||
// declare functions
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
@ -384,7 +410,7 @@ XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson")
|
||||
|
||||
|
||||
// cox regression for survival data (negative values mean they are censored)
|
||||
class CoxRegression : public ObjFunction {
|
||||
class CoxRegression : public RegInitEstimation {
|
||||
public:
|
||||
void Configure(Args const&) override {}
|
||||
ObjInfo Task() const override { return ObjInfo::kRegression; }
|
||||
@ -481,7 +507,7 @@ XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox")
|
||||
.set_body([]() { return new CoxRegression(); });
|
||||
|
||||
// gamma regression
|
||||
class GammaRegression : public ObjFunction {
|
||||
class GammaRegression : public RegInitEstimation {
|
||||
public:
|
||||
void Configure(Args const&) override {}
|
||||
ObjInfo Task() const override { return ObjInfo::kRegression; }
|
||||
@ -572,7 +598,7 @@ struct TweedieRegressionParam : public XGBoostParameter<TweedieRegressionParam>
|
||||
};
|
||||
|
||||
// tweedie regression
|
||||
class TweedieRegression : public ObjFunction {
|
||||
class TweedieRegression : public RegInitEstimation {
|
||||
public:
|
||||
// declare functions
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
|
||||
@ -36,10 +36,10 @@ RUN git clone -b v1.49.1 https://github.com/grpc/grpc.git \
|
||||
rm -rf grpc
|
||||
|
||||
# Create new Conda environment
|
||||
COPY conda_env/cpu_test.yml /scripts/
|
||||
RUN mamba env create -n cpu_test --file=/scripts/cpu_test.yml && \
|
||||
COPY conda_env/linux_cpu_test.yml /scripts/
|
||||
RUN mamba env create -n linux_cpu_test --file=/scripts/linux_cpu_test.yml && \
|
||||
mamba clean --all && \
|
||||
conda run --no-capture-output -n cpu_test pip install buildkite-test-collector
|
||||
conda run --no-capture-output -n linux_cpu_test pip install buildkite-test-collector
|
||||
|
||||
# Install lightweight sudo (not bound to TTY)
|
||||
RUN set -ex; \
|
||||
|
||||
@ -1,8 +1,12 @@
|
||||
name: cpu_test
|
||||
name: linux_cpu_test
|
||||
channels:
|
||||
- conda-forge
|
||||
dependencies:
|
||||
- python=3.8
|
||||
- cmake
|
||||
- c-compiler
|
||||
- cxx-compiler
|
||||
- ninja
|
||||
- pip
|
||||
- wheel
|
||||
- pyyaml
|
||||
@ -33,7 +37,7 @@ dependencies:
|
||||
- pyarrow
|
||||
- protobuf
|
||||
- cloudpickle
|
||||
- shap
|
||||
- shap>=0.41
|
||||
- modin
|
||||
- pip:
|
||||
- datatable
|
||||
@ -146,13 +146,17 @@ def main(args: argparse.Namespace) -> None:
|
||||
"tests/python/test_data_iterator.py",
|
||||
"tests/python/test_dt.py",
|
||||
"tests/python/test_quantile_dmatrix.py",
|
||||
"tests/python/test_tree_regularization.py",
|
||||
"tests/python-gpu/test_gpu_data_iterator.py",
|
||||
"tests/ci_build/lint_python.py",
|
||||
"tests/test_distributed/test_with_spark/",
|
||||
"tests/test_distributed/test_gpu_with_spark/",
|
||||
# demo
|
||||
"demo/json-model/json_parser.py",
|
||||
"demo/guide-python/cat_in_the_dat.py",
|
||||
"demo/guide-python/categorical.py",
|
||||
"demo/guide-python/feature_weights.py",
|
||||
"demo/guide-python/sklearn_parallel.py",
|
||||
"demo/guide-python/spark_estimator_examples.py",
|
||||
# CI
|
||||
"tests/ci_build/lint_python.py",
|
||||
@ -194,6 +198,7 @@ def main(args: argparse.Namespace) -> None:
|
||||
"demo/json-model/json_parser.py",
|
||||
"demo/guide-python/external_memory.py",
|
||||
"demo/guide-python/cat_in_the_dat.py",
|
||||
"demo/guide-python/feature_weights.py",
|
||||
# tests
|
||||
"tests/python/test_dt.py",
|
||||
"tests/python/test_data_iterator.py",
|
||||
|
||||
@ -76,7 +76,7 @@ case "$suite" in
|
||||
;;
|
||||
|
||||
cpu)
|
||||
source activate cpu_test
|
||||
source activate linux_cpu_test
|
||||
set -x
|
||||
install_xgboost
|
||||
export RAY_OBJECT_STORE_ALLOW_SLOW_STORAGE=1
|
||||
|
||||
@ -224,5 +224,6 @@ Arrow specification.'''
|
||||
dtrain = dmatrix_from_cupy(
|
||||
np.float32, xgb.DeviceQuantileDMatrix, np.nan)
|
||||
with pytest.raises(xgb.core.XGBoostError):
|
||||
xgb.train({'tree_method': 'gpu_hist', 'gpu_id': 1},
|
||||
dtrain, num_boost_round=10)
|
||||
xgb.train(
|
||||
{'tree_method': 'gpu_hist', 'gpu_id': 1}, dtrain, num_boost_round=10
|
||||
)
|
||||
|
||||
@ -5,6 +5,7 @@ import numpy as np
|
||||
import pytest
|
||||
from hypothesis import assume, given, note, settings, strategies
|
||||
from xgboost.testing.params import cat_parameter_strategy, hist_parameter_strategy
|
||||
from xgboost.testing.updater import check_init_estimation
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
@ -172,24 +173,25 @@ class TestGPUUpdaters:
|
||||
kCols = 100
|
||||
|
||||
X = np.empty((kRows, kCols))
|
||||
y = np.empty((kRows))
|
||||
y = np.empty((kRows,))
|
||||
|
||||
dtrain = xgb.DMatrix(X, y)
|
||||
|
||||
bst = xgb.train({'verbosity': 2,
|
||||
'tree_method': 'gpu_hist',
|
||||
'gpu_id': 0},
|
||||
bst = xgb.train(
|
||||
{"verbosity": 2, "tree_method": "gpu_hist", "gpu_id": 0},
|
||||
dtrain,
|
||||
verbose_eval=True,
|
||||
num_boost_round=6,
|
||||
evals=[(dtrain, 'Train')])
|
||||
evals=[(dtrain, 'Train')]
|
||||
)
|
||||
|
||||
kRows = 100
|
||||
X = np.random.randn(kRows, kCols)
|
||||
|
||||
dtest = xgb.DMatrix(X)
|
||||
predictions = bst.predict(dtest)
|
||||
np.testing.assert_allclose(predictions, 0.5, 1e-6)
|
||||
# non-distributed, 0.0 is returned due to base_score estimation with 0 gradient.
|
||||
np.testing.assert_allclose(predictions, 0.0, 1e-6)
|
||||
|
||||
@pytest.mark.mgpu
|
||||
@given(tm.dataset_strategy, strategies.integers(0, 10))
|
||||
@ -204,3 +206,6 @@ class TestGPUUpdaters:
|
||||
@pytest.mark.parametrize("weighted", [True, False])
|
||||
def test_adaptive(self, weighted) -> None:
|
||||
self.cputest.run_adaptive("gpu_hist", weighted)
|
||||
|
||||
def test_init_estimation(self) -> None:
|
||||
check_init_estimation("gpu_hist")
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from xgboost.testing.updater import get_basescore
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
@ -11,16 +12,12 @@ class TestEarlyStopping:
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_early_stopping_nonparallel(self):
|
||||
from sklearn.datasets import load_digits
|
||||
try:
|
||||
from sklearn.model_selection import train_test_split
|
||||
except ImportError:
|
||||
from sklearn.cross_validation import train_test_split
|
||||
|
||||
digits = load_digits(n_class=2)
|
||||
X = digits['data']
|
||||
y = digits['target']
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y,
|
||||
random_state=0)
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
|
||||
clf1 = xgb.XGBClassifier(learning_rate=0.1)
|
||||
clf1.fit(X_train, y_train, early_stopping_rounds=5, eval_metric="auc",
|
||||
eval_set=[(X_test, y_test)])
|
||||
@ -31,9 +28,23 @@ class TestEarlyStopping:
|
||||
assert clf1.best_score == clf2.best_score
|
||||
assert clf1.best_score != 1
|
||||
# check overfit
|
||||
clf3 = xgb.XGBClassifier(learning_rate=0.1)
|
||||
clf3.fit(X_train, y_train, early_stopping_rounds=10, eval_metric="auc",
|
||||
eval_set=[(X_test, y_test)])
|
||||
clf3 = xgb.XGBClassifier(
|
||||
learning_rate=0.1,
|
||||
eval_metric="auc",
|
||||
early_stopping_rounds=10
|
||||
)
|
||||
clf3.fit(X_train, y_train, eval_set=[(X_test, y_test)])
|
||||
base_score = get_basescore(clf3)
|
||||
assert 0.53 > base_score > 0.5
|
||||
|
||||
clf3 = xgb.XGBClassifier(
|
||||
learning_rate=0.1,
|
||||
base_score=.5,
|
||||
eval_metric="auc",
|
||||
early_stopping_rounds=10
|
||||
)
|
||||
clf3.fit(X_train, y_train, eval_set=[(X_test, y_test)])
|
||||
|
||||
assert clf3.best_score == 1
|
||||
|
||||
def evalerror(self, preds, dtrain):
|
||||
|
||||
@ -9,11 +9,13 @@ train_data = xgb.DMatrix(np.array([[1]]), label=np.array([1]))
|
||||
class TestTreeRegularization:
|
||||
def test_alpha(self):
|
||||
params = {
|
||||
'tree_method': 'exact', 'verbosity': 0,
|
||||
'objective': 'reg:squarederror',
|
||||
'eta': 1,
|
||||
'lambda': 0,
|
||||
'alpha': 0.1
|
||||
"tree_method": "exact",
|
||||
"verbosity": 0,
|
||||
"objective": "reg:squarederror",
|
||||
"eta": 1,
|
||||
"lambda": 0,
|
||||
"alpha": 0.1,
|
||||
"base_score": 0.5,
|
||||
}
|
||||
|
||||
model = xgb.train(params, train_data, 1)
|
||||
@ -27,11 +29,13 @@ class TestTreeRegularization:
|
||||
|
||||
def test_lambda(self):
|
||||
params = {
|
||||
'tree_method': 'exact', 'verbosity': 0,
|
||||
'objective': 'reg:squarederror',
|
||||
'eta': 1,
|
||||
'lambda': 1,
|
||||
'alpha': 0
|
||||
"tree_method": "exact",
|
||||
"verbosity": 0,
|
||||
"objective": "reg:squarederror",
|
||||
"eta": 1,
|
||||
"lambda": 1,
|
||||
"alpha": 0,
|
||||
"base_score": 0.5,
|
||||
}
|
||||
|
||||
model = xgb.train(params, train_data, 1)
|
||||
@ -45,11 +49,13 @@ class TestTreeRegularization:
|
||||
|
||||
def test_alpha_and_lambda(self):
|
||||
params = {
|
||||
'tree_method': 'exact', 'verbosity': 1,
|
||||
'objective': 'reg:squarederror',
|
||||
'eta': 1,
|
||||
'lambda': 1,
|
||||
'alpha': 0.1
|
||||
"tree_method": "exact",
|
||||
"verbosity": 1,
|
||||
"objective": "reg:squarederror",
|
||||
"eta": 1,
|
||||
"lambda": 1,
|
||||
"alpha": 0.1,
|
||||
"base_score": 0.5,
|
||||
}
|
||||
|
||||
model = xgb.train(params, train_data, 1)
|
||||
|
||||
@ -10,6 +10,7 @@ from xgboost.testing.params import (
|
||||
exact_parameter_strategy,
|
||||
hist_parameter_strategy,
|
||||
)
|
||||
from xgboost.testing.updater import check_init_estimation
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
@ -449,3 +450,6 @@ class TestTreeMethod:
|
||||
)
|
||||
def test_adaptive(self, tree_method, weighted) -> None:
|
||||
self.run_adaptive(tree_method, weighted)
|
||||
|
||||
def test_init_estimation(self) -> None:
|
||||
check_init_estimation("hist")
|
||||
|
||||
@ -9,6 +9,7 @@ except Exception:
|
||||
shap = None
|
||||
pass
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(shap is None, reason="Requires shap package")
|
||||
|
||||
|
||||
@ -16,11 +17,16 @@ pytestmark = pytest.mark.skipif(shap is None, reason="Requires shap package")
|
||||
# Changes in binary format may cause problems
|
||||
def test_with_shap():
|
||||
from sklearn.datasets import fetch_california_housing
|
||||
|
||||
X, y = fetch_california_housing(return_X_y=True)
|
||||
dtrain = xgb.DMatrix(X, label=y)
|
||||
model = xgb.train({"learning_rate": 0.01}, dtrain, 10)
|
||||
explainer = shap.TreeExplainer(model)
|
||||
shap_values = explainer.shap_values(X)
|
||||
margin = model.predict(dtrain, output_margin=True)
|
||||
assert np.allclose(np.sum(shap_values, axis=len(shap_values.shape) - 1),
|
||||
margin - explainer.expected_value, 1e-3, 1e-3)
|
||||
assert np.allclose(
|
||||
np.sum(shap_values, axis=len(shap_values.shape) - 1),
|
||||
margin - explainer.expected_value,
|
||||
1e-3,
|
||||
1e-3,
|
||||
)
|
||||
|
||||
@ -9,6 +9,7 @@ import numpy as np
|
||||
import pytest
|
||||
from sklearn.utils.estimator_checks import parametrize_with_checks
|
||||
from xgboost.testing.shared import get_feature_weights, validate_data_initialization
|
||||
from xgboost.testing.updater import get_basescore
|
||||
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
@ -196,19 +197,22 @@ def test_stacking_classification():
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
|
||||
clf.fit(X_train, y_train).score(X_test, y_test)
|
||||
|
||||
|
||||
@pytest.mark.skipif(**tm.no_pandas())
|
||||
def test_feature_importances_weight():
|
||||
from sklearn.datasets import load_digits
|
||||
|
||||
digits = load_digits(n_class=2)
|
||||
y = digits['target']
|
||||
X = digits['data']
|
||||
y = digits["target"]
|
||||
X = digits["data"]
|
||||
|
||||
xgb_model = xgb.XGBClassifier(random_state=0,
|
||||
xgb_model = xgb.XGBClassifier(
|
||||
random_state=0,
|
||||
tree_method="exact",
|
||||
learning_rate=0.1,
|
||||
importance_type="weight").fit(X, y)
|
||||
importance_type="weight",
|
||||
base_score=0.5,
|
||||
).fit(X, y)
|
||||
|
||||
exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.00833333, 0.,
|
||||
0., 0., 0., 0., 0., 0., 0., 0.025, 0.14166667, 0., 0., 0.,
|
||||
0., 0., 0., 0.00833333, 0.25833333, 0., 0., 0., 0.,
|
||||
@ -223,16 +227,22 @@ def test_feature_importances_weight():
|
||||
import pandas as pd
|
||||
y = pd.Series(digits['target'])
|
||||
X = pd.DataFrame(digits['data'])
|
||||
xgb_model = xgb.XGBClassifier(random_state=0,
|
||||
xgb_model = xgb.XGBClassifier(
|
||||
random_state=0,
|
||||
tree_method="exact",
|
||||
learning_rate=0.1,
|
||||
importance_type="weight").fit(X, y)
|
||||
base_score=.5,
|
||||
importance_type="weight"
|
||||
).fit(X, y)
|
||||
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
|
||||
|
||||
xgb_model = xgb.XGBClassifier(random_state=0,
|
||||
xgb_model = xgb.XGBClassifier(
|
||||
random_state=0,
|
||||
tree_method="exact",
|
||||
learning_rate=0.1,
|
||||
importance_type="weight").fit(X, y)
|
||||
importance_type="weight",
|
||||
base_score=.5,
|
||||
).fit(X, y)
|
||||
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
@ -274,6 +284,7 @@ def test_feature_importances_gain():
|
||||
random_state=0, tree_method="exact",
|
||||
learning_rate=0.1,
|
||||
importance_type="gain",
|
||||
base_score=0.5,
|
||||
).fit(X, y)
|
||||
|
||||
exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
|
||||
@ -296,6 +307,7 @@ def test_feature_importances_gain():
|
||||
tree_method="exact",
|
||||
learning_rate=0.1,
|
||||
importance_type="gain",
|
||||
base_score=0.5,
|
||||
).fit(X, y)
|
||||
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
|
||||
|
||||
@ -304,6 +316,7 @@ def test_feature_importances_gain():
|
||||
tree_method="exact",
|
||||
learning_rate=0.1,
|
||||
importance_type="gain",
|
||||
base_score=0.5,
|
||||
).fit(X, y)
|
||||
np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
|
||||
|
||||
@ -593,18 +606,21 @@ def test_split_value_histograms():
|
||||
|
||||
digits_2class = load_digits(n_class=2)
|
||||
|
||||
X = digits_2class['data']
|
||||
y = digits_2class['target']
|
||||
X = digits_2class["data"]
|
||||
y = digits_2class["target"]
|
||||
|
||||
dm = xgb.DMatrix(X, label=y)
|
||||
params = {'max_depth': 6, 'eta': 0.01, 'verbosity': 0,
|
||||
'objective': 'binary:logistic'}
|
||||
params = {
|
||||
"max_depth": 6,
|
||||
"eta": 0.01,
|
||||
"verbosity": 0,
|
||||
"objective": "binary:logistic",
|
||||
"base_score": 0.5,
|
||||
}
|
||||
|
||||
gbdt = xgb.train(params, dm, num_boost_round=10)
|
||||
assert gbdt.get_split_value_histogram("not_there",
|
||||
as_pandas=True).shape[0] == 0
|
||||
assert gbdt.get_split_value_histogram("not_there",
|
||||
as_pandas=False).shape[0] == 0
|
||||
assert gbdt.get_split_value_histogram("not_there", as_pandas=True).shape[0] == 0
|
||||
assert gbdt.get_split_value_histogram("not_there", as_pandas=False).shape[0] == 0
|
||||
assert gbdt.get_split_value_histogram("f28", bins=0).shape[0] == 1
|
||||
assert gbdt.get_split_value_histogram("f28", bins=1).shape[0] == 1
|
||||
assert gbdt.get_split_value_histogram("f28", bins=2).shape[0] == 2
|
||||
@ -748,11 +764,7 @@ def test_sklearn_get_default_params():
|
||||
cls = xgb.XGBClassifier()
|
||||
assert cls.get_params()["base_score"] is None
|
||||
cls.fit(X[:4, ...], y[:4, ...])
|
||||
base_score = float(
|
||||
json.loads(cls.get_booster().save_config())["learner"]["learner_model_param"][
|
||||
"base_score"
|
||||
]
|
||||
)
|
||||
base_score = get_basescore(cls)
|
||||
np.testing.assert_equal(base_score, 0.5)
|
||||
|
||||
|
||||
|
||||
@ -42,6 +42,7 @@ try:
|
||||
from dask import array as da
|
||||
from dask.distributed import Client
|
||||
from dask_cuda import LocalCUDACluster
|
||||
from xgboost.testing.dask import check_init_estimation
|
||||
|
||||
from xgboost import dask as dxgb
|
||||
except ImportError:
|
||||
@ -220,6 +221,9 @@ class TestDistributedGPU:
|
||||
y = dd.from_array(y_, chunksize=50).map_partitions(cudf.from_pandas)
|
||||
run_boost_from_prediction_multi_class(X, y, "gpu_hist", local_cuda_client)
|
||||
|
||||
def test_init_estimation(self, local_cuda_client: Client) -> None:
|
||||
check_init_estimation("gpu_hist", local_cuda_client)
|
||||
|
||||
@pytest.mark.skipif(**tm.no_dask_cudf())
|
||||
def test_dask_dataframe(self, local_cuda_client: Client) -> None:
|
||||
run_with_dask_dataframe(dxgb.DaskDMatrix, local_cuda_client)
|
||||
|
||||
@ -12,7 +12,7 @@ from itertools import starmap
|
||||
from math import ceil
|
||||
from operator import attrgetter, getitem
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Generator, Optional, Tuple, Type, Union
|
||||
from typing import Any, Dict, Generator, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import hypothesis
|
||||
import numpy as np
|
||||
@ -32,7 +32,7 @@ from xgboost.testing.shared import (
|
||||
import xgboost as xgb
|
||||
from xgboost import testing as tm
|
||||
|
||||
pytestmark = [tm.timeout(30), pytest.mark.skipif(**tm.no_dask())]
|
||||
pytestmark = [tm.timeout(60), pytest.mark.skipif(**tm.no_dask())]
|
||||
|
||||
import dask
|
||||
import dask.array as da
|
||||
@ -40,6 +40,7 @@ import dask.dataframe as dd
|
||||
from distributed import Client, LocalCluster
|
||||
from toolz import sliding_window # dependency of dask
|
||||
from xgboost.dask import DaskDMatrix
|
||||
from xgboost.testing.dask import check_init_estimation
|
||||
|
||||
dask.config.set({"distributed.scheduler.allowed-failures": False})
|
||||
|
||||
@ -52,8 +53,10 @@ else:
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def cluster() -> Generator:
|
||||
n_threads = os.cpu_count()
|
||||
assert n_threads is not None
|
||||
with LocalCluster(
|
||||
n_workers=2, threads_per_worker=2, dashboard_address=":0"
|
||||
n_workers=2, threads_per_worker=n_threads // 2, dashboard_address=":0"
|
||||
) as dask_cluster:
|
||||
yield dask_cluster
|
||||
|
||||
@ -151,12 +154,15 @@ def deterministic_persist_per_worker(df: dd.DataFrame, client: "Client") -> dd.D
|
||||
return df2
|
||||
|
||||
|
||||
Margin = TypeVar("Margin", dd.DataFrame, dd.Series, None)
|
||||
|
||||
|
||||
def deterministic_repartition(
|
||||
client: Client,
|
||||
X: dd.DataFrame,
|
||||
y: dd.Series,
|
||||
m: Optional[Union[dd.DataFrame, dd.Series]],
|
||||
) -> Tuple[dd.DataFrame, dd.Series, Optional[Union[dd.DataFrame, dd.Series]]]:
|
||||
m: Margin,
|
||||
) -> Tuple[dd.DataFrame, dd.Series, Margin]:
|
||||
# force repartition the data to avoid non-deterministic result
|
||||
if any(X.map_partitions(lambda x: _is_cudf_df(x)).compute()):
|
||||
# dask_cudf seems to be doing fine for now
|
||||
@ -474,14 +480,20 @@ def run_boost_from_prediction(
|
||||
X, y, margin = deterministic_repartition(client, X, y, margin)
|
||||
predictions_1: dd.Series = model_1.predict(X, base_margin=margin)
|
||||
|
||||
cls_2 = xgb.dask.DaskXGBClassifier(
|
||||
model_2 = xgb.dask.DaskXGBClassifier(
|
||||
learning_rate=0.3, n_estimators=8, tree_method=tree_method, max_bin=512
|
||||
)
|
||||
X, y, _ = deterministic_repartition(client, X, y, None)
|
||||
cls_2.fit(X=X, y=y)
|
||||
predictions_2: dd.Series = cls_2.predict(X)
|
||||
model_2.fit(X=X, y=y)
|
||||
predictions_2: dd.Series = model_2.predict(X)
|
||||
|
||||
assert np.all(predictions_1.compute() == predictions_2.compute())
|
||||
predt_1 = predictions_1.compute()
|
||||
predt_2 = predictions_2.compute()
|
||||
if hasattr(predt_1, "to_numpy"):
|
||||
predt_1 = predt_1.to_numpy()
|
||||
if hasattr(predt_2, "to_numpy"):
|
||||
predt_2 = predt_2.to_numpy()
|
||||
np.testing.assert_allclose(predt_1, predt_2, atol=1e-5)
|
||||
|
||||
margined = xgb.dask.DaskXGBClassifier(n_estimators=4)
|
||||
X, y, margin = deterministic_repartition(client, X, y, margin)
|
||||
@ -706,6 +718,7 @@ def run_dask_classifier(
|
||||
def test_dask_classifier(model: str, client: "Client") -> None:
|
||||
X, y, w = generate_array(with_weights=True)
|
||||
y = (y * 10).astype(np.int32)
|
||||
assert w is not None
|
||||
run_dask_classifier(X, y, w, model, None, client, 10)
|
||||
|
||||
y_bin = y.copy()
|
||||
@ -1386,16 +1399,22 @@ class TestWithDask:
|
||||
else:
|
||||
w = None
|
||||
|
||||
m = xgb.dask.DaskDMatrix(
|
||||
client, data=X, label=y, weight=w)
|
||||
history = xgb.dask.train(client, params=params, dtrain=m,
|
||||
m = xgb.dask.DaskDMatrix(client, data=X, label=y, weight=w)
|
||||
history = xgb.dask.train(
|
||||
client,
|
||||
params=params,
|
||||
dtrain=m,
|
||||
num_boost_round=num_rounds,
|
||||
evals=[(m, 'train')])['history']
|
||||
evals=[(m, "train")],
|
||||
)["history"]
|
||||
note(history)
|
||||
history = history['train'][dataset.metric]
|
||||
history = history["train"][dataset.metric]
|
||||
|
||||
def is_stump() -> bool:
|
||||
return params["max_depth"] == 1 or params["max_leaves"] == 1
|
||||
def is_stump():
|
||||
return (
|
||||
params.get("max_depth", None) == 1
|
||||
or params.get("max_leaves", None) == 1
|
||||
)
|
||||
|
||||
def minimum_bin() -> bool:
|
||||
return "max_bin" in params and params["max_bin"] == 2
|
||||
@ -1410,6 +1429,10 @@ class TestWithDask:
|
||||
else:
|
||||
assert tm.non_increasing(history)
|
||||
# Make sure that it's decreasing
|
||||
if is_stump():
|
||||
# we might have already got the best score with base_score.
|
||||
assert history[-1] <= history[0]
|
||||
else:
|
||||
assert history[-1] < history[0]
|
||||
|
||||
@given(params=hist_parameter_strategy,
|
||||
@ -1646,13 +1669,17 @@ class TestWithDask:
|
||||
|
||||
results_custom = reg.evals_result()
|
||||
|
||||
reg = xgb.dask.DaskXGBRegressor(n_estimators=rounds, tree_method='hist')
|
||||
reg = xgb.dask.DaskXGBRegressor(
|
||||
n_estimators=rounds, tree_method="hist", base_score=0.5
|
||||
)
|
||||
reg.fit(X, y, eval_set=[(X, y)])
|
||||
results_native = reg.evals_result()
|
||||
|
||||
np.testing.assert_allclose(results_custom['validation_0']['rmse'],
|
||||
results_native['validation_0']['rmse'])
|
||||
tm.non_increasing(results_native['validation_0']['rmse'])
|
||||
np.testing.assert_allclose(
|
||||
results_custom["validation_0"]["rmse"],
|
||||
results_native["validation_0"]["rmse"],
|
||||
)
|
||||
tm.non_increasing(results_native["validation_0"]["rmse"])
|
||||
|
||||
def test_no_duplicated_partition(self) -> None:
|
||||
'''Assert each worker has the correct amount of data, and DMatrix initialization doesn't
|
||||
@ -1994,6 +2021,10 @@ def test_parallel_submit_multi_clients() -> None:
|
||||
assert f.result().get_booster().num_boosted_rounds() == i + 1
|
||||
|
||||
|
||||
def test_init_estimation(client: Client) -> None:
|
||||
check_init_estimation("hist", client)
|
||||
|
||||
|
||||
class TestDaskCallbacks:
|
||||
@pytest.mark.skipif(**tm.no_sklearn())
|
||||
def test_early_stopping(self, client: "Client") -> None:
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
import numpy as np
|
||||
@ -216,7 +215,7 @@ class XgboostLocalClusterTestCase(SparkLocalClusterTestCase):
|
||||
],
|
||||
)
|
||||
self.reg_best_score_eval = 5.239e-05
|
||||
self.reg_best_score_weight_and_eval = 4.810e-05
|
||||
self.reg_best_score_weight_and_eval = 4.850e-05
|
||||
|
||||
def test_regressor_basic_with_params(self):
|
||||
regressor = SparkXGBRegressor(**self.reg_params)
|
||||
|
||||
@ -4,16 +4,15 @@ import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from io import StringIO
|
||||
|
||||
import pytest
|
||||
from six import StringIO
|
||||
|
||||
from xgboost import testing as tm
|
||||
|
||||
pytestmark = [pytest.mark.skipif(**tm.no_spark())]
|
||||
|
||||
|
||||
from pyspark.sql import SparkSession, SQLContext
|
||||
from pyspark.sql import SparkSession
|
||||
from xgboost.spark.utils import _get_default_params_from_func
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user