Enforce correct data shape. (#5191)
* Fix syncing DMatrix columns.
* Notes for tree method.
* Enable feature validation for all interfaces except for JVM.
* Better tests for boosting from predictions.
* Disable validation on JVM.
parent 8cbcc53ccb
commit 7b65698187
@@ -112,18 +112,24 @@ Parameters for Tree Booster
   - The tree construction algorithm used in XGBoost. See description in the `reference paper <http://arxiv.org/abs/1603.02754>`_.
   - XGBoost supports ``approx``, ``hist`` and ``gpu_hist`` for distributed training. Experimental support for external memory is available for ``approx`` and ``gpu_hist``.
-  - Choices: ``auto``, ``exact``, ``approx``, ``hist``, ``gpu_hist``
+  - Choices: ``auto``, ``exact``, ``approx``, ``hist``, ``gpu_hist``. This is a
+    combination of commonly used updaters.  For other updaters like ``refresh``, set the
+    parameter ``updater`` directly.

     - ``auto``: Use heuristic to choose the fastest method.

-      - For small to medium dataset, exact greedy (``exact``) will be used.
-      - For very large dataset, approximate algorithm (``approx``) will be chosen.
-      - Because old behavior is always use exact greedy in single machine,
-        user will get a message when approximate algorithm is chosen to notify this choice.
+      - For small dataset, exact greedy (``exact``) will be used.
+      - For larger dataset, approximate algorithm (``approx``) will be chosen.  It's
+        recommended to try ``hist`` and ``gpu_hist`` for higher performance with large
+        dataset.  (``gpu_hist``) has support for ``external memory``.
+      - Because the old behavior is to always use exact greedy on a single machine, the
+        user will get a message when the approximate algorithm is chosen to notify this choice.

-    - ``exact``: Exact greedy algorithm.
+    - ``exact``: Exact greedy algorithm.  Enumerates all split candidates.
     - ``approx``: Approximate greedy algorithm using quantile sketch and gradient histogram.
-    - ``hist``: Fast histogram optimized approximate greedy algorithm. It uses some performance improvements such as bins caching.
+    - ``hist``: Faster histogram optimized approximate greedy algorithm.
     - ``gpu_hist``: GPU implementation of ``hist`` algorithm.

 * ``sketch_eps`` [default=0.03]
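For context on the new wording: ``tree_method`` selects a preset combination of updaters, while standalone updaters such as ``refresh`` go through the ``updater`` parameter. A minimal sketch with the Python package (synthetic data; parameter values are taken from the documentation above, not from this commit's code):

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(500, 8)
    y = np.random.randint(2, size=500)
    dtrain = xgb.DMatrix(X, label=y)

    # Common case: pick one of the combined updater presets.
    bst = xgb.train({"objective": "binary:logistic", "tree_method": "approx"},
                    dtrain, num_boost_round=10)

    # Standalone updaters are set through ``updater`` directly: here we only
    # refresh the leaf statistics of the existing 10 trees.
    refreshed = xgb.train(
        {"objective": "binary:logistic", "process_type": "update",
         "updater": "refresh", "refresh_leaf": True},
        dtrain, num_boost_round=10, xgb_model=bst)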
@@ -38,6 +38,11 @@ There are in general two ways that you can control overfitting in XGBoost:
   - This includes ``subsample`` and ``colsample_bytree``.
 - You can also reduce stepsize ``eta``. Remember to increase ``num_round`` when you do so.

+***************************
+Faster training performance
+***************************
+There's a parameter called ``tree_method``; set it to ``hist`` or ``gpu_hist`` for faster computation.
+
 *************************
 Handle Imbalanced Dataset
 *************************
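The new "Faster training performance" section in runnable form; a hedged sketch with synthetic data (nothing here is taken from the commit itself):

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(10000, 20)
    y = np.random.randint(2, size=10000)
    dtrain = xgb.DMatrix(X, label=y)

    # Histogram-based tree construction is usually much faster on large data;
    # swap in "gpu_hist" to run the same algorithm on a GPU.
    params = {"objective": "binary:logistic", "tree_method": "hist"}
    bst = xgb.train(params, dtrain, num_boost_round=50)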
@@ -29,6 +29,7 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
   size_t gpu_page_size;
   bool enable_experimental_json_serialization {false};
   bool validate_parameters {false};
+  bool validate_features {true};

   void CheckDeprecated() {
     if (this->n_gpus != 0) {
@@ -73,7 +74,10 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
               "rabit checkpoints etc.).");
     DMLC_DECLARE_FIELD(validate_parameters)
         .set_default(false)
-        .describe("Enable to check whether parameters are used or not.");
+        .describe("Enable checking whether parameters are used or not.");
+    DMLC_DECLARE_FIELD(validate_features)
+        .set_default(false)
+        .describe("Enable validating input DMatrix.");
     DMLC_DECLARE_FIELD(n_gpus)
         .set_default(0)
         .set_range(0, 1)
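Since ``validate_features`` is declared as an ordinary generic parameter, it should be togglable by name like any other parameter. A hedged sketch from Python (assumption: the binding passes the string through to the booster, as the JVM change below does with ``setParam``):

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(100, 4)
    y = np.random.randint(2, size=100)
    bst = xgb.train({"objective": "binary:logistic"},
                    xgb.DMatrix(X, label=y), num_boost_round=5)

    # Opt out of the shape check, mirroring what the JVM package does.
    bst.set_param("validate_features", "0")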
@@ -49,7 +49,8 @@ public class Booster implements Serializable, KryoSerializable {
    */
   Booster(Map<String, Object> params, DMatrix[] cacheMats) throws XGBoostError {
     init(cacheMats);
     setParam("seed", "0");
+    setParam("validate_features", "0");
     setParams(params);
   }

@@ -71,6 +71,12 @@ class TrainingObserver {
     auto const& h_vec = vec.HostVector();
     this->Observe(h_vec, name);
   }
+  template <typename T>
+  void Observe(HostDeviceVector<T>* vec, std::string name) const {
+    if (XGBOOST_EXPECT(!observe_, true)) { return; }
+    this->Observe(*vec, name);
+  }

   /*\brief Observe objects with `XGBoostParamer' type. */
   template <typename Parameter,
             typename std::enable_if<
@@ -295,32 +295,8 @@ void DMatrix::SaveToLocalFile(const std::string& fname) {
 DMatrix* DMatrix::Create(std::unique_ptr<DataSource<SparsePage>>&& source,
                          const std::string& cache_prefix) {
   if (cache_prefix.length() == 0) {
-    // FIXME(trivialfis): Currently distcol is broken so we here check for number of rows.
-    // If we bring back column split this check will break.
-    bool is_distributed { rabit::IsDistributed() };
-    if (is_distributed) {
-      auto world_size = rabit::GetWorldSize();
-      auto rank = rabit::GetRank();
-      std::vector<uint64_t> ncols(world_size, 0);
-      ncols[rank] = source->info.num_col_;
-      rabit::Allreduce<rabit::op::Sum>(ncols.data(), ncols.size());
-      auto max_cols = std::max_element(ncols.cbegin(), ncols.cend());
-      auto max_ind = std::distance(ncols.cbegin(), max_cols);
-      // FIXME(trivialfis): This is a hack, we should store a reference to global shape if possible.
-      if (source->info.num_col_ == 0 && source->info.num_row_ == 0) {
-        LOG(WARNING) << "DMatrix at rank: " << rank << " worker is empty.";
-        source->info.num_col_ = *max_cols;
-      }
-
-      // validate the number of columns across all workers.
-      for (size_t i = 0; i < ncols.size(); ++i) {
-        auto v = ncols[i];
-        CHECK(v == 0 || v == *max_cols)
-            << "DMatrix at rank: " << i << " worker "
-            << "has different number of columns than rank: " << max_ind << " worker. "
-            << "(" << v << " vs. " << *max_cols << ")";
-      }
-    }
+    // Data split mode is fixed to be row right now.
+    rabit::Allreduce<rabit::op::Max>(&source->info.num_col_, 1);
     return new data::SimpleDMatrix(std::move(source));
   } else {
 #if DMLC_ENABLE_STD_THREAD
@@ -336,6 +312,7 @@ template <typename AdapterT>
 DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
                          const std::string& cache_prefix, size_t page_size) {
   if (cache_prefix.length() == 0) {
+    // Data split mode is fixed to be row right now.
     return new data::SimpleDMatrix(adapter, missing, nthread);
   } else {
 #if DMLC_ENABLE_STD_THREAD
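The replacement collapses the per-rank bookkeeping into a single max-reduction: every worker contributes its local column count and adopts the global maximum, so an empty shard still gets the right width. A toy illustration of the idea in plain Python (not the rabit API):

    # Each rank reports the width of its local data shard; a worker that
    # received no rows reports 0 columns.
    local_num_cols = [26, 26, 0]  # rank 2 received no data

    # An allreduce with `max` gives every rank the global width, so rank 2
    # still constructs its DMatrix with the correct number of columns.
    global_num_cols = max(local_num_cols)
    assert all(n in (0, global_num_cols) for n in local_num_cols)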
@@ -124,7 +124,6 @@ void GBTree::PerformTreeMethodHeuristic(DMatrix* fmat) {
     return;
   }

-  tparam_.updater_seq = "grow_histmaker,prune";
   if (rabit::IsDistributed()) {
     LOG(WARNING) <<
         "Tree method is automatically selected to be 'approx' "
@@ -925,6 +925,25 @@ class LearnerImpl : public Learner {
           << "num rows: " << p_fmat->Info().num_row_ << "\n"
           << "Number of weights should be equal to number of groups in ranking task.";
     }
+
+    auto const row_based_split = [this]() {
+      return tparam_.dsplit == DataSplitMode::kRow ||
+             tparam_.dsplit == DataSplitMode::kAuto;
+    };
+    bool const valid_features =
+        !row_based_split() ||
+        (learner_model_param_.num_feature == p_fmat->Info().num_col_);
+    std::string const msg {
+      "Number of columns does not match number of features in booster."
+    };
+    if (generic_parameters_.validate_features) {
+      CHECK_EQ(learner_model_param_.num_feature, p_fmat->Info().num_col_) << msg;
+    } else if (!valid_features) {
+      // Remove this and make the equality check fatal once spark can fix all failing tests.
+      LOG(WARNING) << msg << " "
+                   << "Columns: " << p_fmat->Info().num_col_ << " "
+                   << "Features: " << learner_model_param_.num_feature;
+    }
   }

   // model parameter
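From the user's side, this check fires when the matrix passed to predict is narrower or wider than the training data. A hedged sketch (synthetic data; whether it raises or only warns depends on ``validate_features``, and the Python layer may reject the mismatch first via auto-generated feature names):

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(100, 10)
    y = np.random.randint(2, size=100)
    bst = xgb.train({"objective": "binary:logistic"},
                    xgb.DMatrix(X, label=y), num_boost_round=5)

    # 9 columns instead of the 10 the booster was trained on.
    bad = xgb.DMatrix(np.random.randn(100, 9))
    try:
        bst.predict(bad)
    except (ValueError, xgb.core.XGBoostError) as err:
        print(err)  # number of columns does not match number of features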
@@ -80,6 +80,10 @@ class ColMaker: public TreeUpdater {
   void Update(HostDeviceVector<GradientPair> *gpair,
               DMatrix* dmat,
               const std::vector<RegTree*> &trees) override {
+    if (rabit::IsDistributed()) {
+      LOG(FATAL) << "Updater `grow_colmaker` or `exact` tree method doesn't "
+                    "support distributed training.";
+    }
     // rescale learning rate according to size of trees
     float lr = param_.learning_rate;
     param_.learning_rate = lr / trees.size();
@@ -91,7 +91,6 @@ TEST(Learner, CheckGroup) {
 }

 TEST(Learner, SLOW_CheckMultiBatch) {
-  using Arg = std::pair<std::string, std::string>;
   // Create sufficiently large data to make two row pages
   dmlc::TemporaryDirectory tempdir;
   const std::string tmp_file = tempdir.path + "/big.libsvm";
@@ -107,7 +106,7 @@ TEST(Learner, SLOW_CheckMultiBatch) {
   dmat->Info().SetInfo("label", labels.data(), DataType::kFloat32, num_row);
   std::vector<std::shared_ptr<DMatrix>> mat{dmat};
   auto learner = std::unique_ptr<Learner>(Learner::Create(mat));
-  learner->SetParams({Arg{"objective", "binary:logistic"}, Arg{"verbosity", "3"}});
+  learner->SetParams(Args{{"objective", "binary:logistic"}});
   learner->UpdateOneIter(0, dmat.get());
 }
@@ -6,7 +6,6 @@ import subprocess
 import os
-import json
 import pytest
 import copy

 import xgboost as xgb
 from xgboost import XGBClassifier
@@ -13,6 +13,10 @@ class TestGPUPredict(unittest.TestCase):
         np.random.seed(1)
         test_num_rows = [10, 1000, 5000]
         test_num_cols = [10, 50, 500]
+        # This test passes for tree_method=gpu_hist and tree_method=exact, but
+        # for `hist` and `approx` the floating point error accumulates faster
+        # and it fails even when tol is set to 1e-4.  For `hist`, the
+        # mismatching rate with 5000 rows is 0.04.
         for num_rows in test_num_rows:
             for num_cols in test_num_cols:
                 dtrain = xgb.DMatrix(np.random.randn(num_rows, num_cols),
@@ -27,7 +31,7 @@ class TestGPUPredict(unittest.TestCase):
                     "objective": "binary:logistic",
                     "predictor": "gpu_predictor",
                     'eval_metric': 'auc',
-                    'verbosity': '3'
+                    'tree_method': 'gpu_hist'
                 }
                 bst = xgb.train(param, dtrain, iterations, evals=watchlist,
                                 evals_result=res)
@@ -43,11 +47,11 @@ class TestGPUPredict(unittest.TestCase):
                 cpu_pred_val = bst_cpu.predict(dval, output_margin=True)

                 np.testing.assert_allclose(cpu_pred_train, gpu_pred_train,
-                                           rtol=1e-3)
+                                           rtol=1e-6)
                 np.testing.assert_allclose(cpu_pred_val, gpu_pred_val,
-                                           rtol=1e-3)
+                                           rtol=1e-6)
                 np.testing.assert_allclose(cpu_pred_test, gpu_pred_test,
-                                           rtol=1e-3)
+                                           rtol=1e-6)

     def non_decreasing(self, L):
         return all((x - y) < 0.001 for x, y in zip(L, L[1:]))
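For reference, ``rtol`` in ``np.testing.assert_allclose`` bounds the elementwise relative difference (``|actual - desired| <= atol + rtol * |desired|``), so the change above tightens the required CPU/GPU agreement by three orders of magnitude. A small self-contained illustration with invented values:

    import numpy as np

    cpu = np.array([0.5000001, 1.0000004])
    gpu = np.array([0.5000000, 1.0000000])

    # Passes at the tightened tolerance: differences are below 1e-6 * |gpu|.
    np.testing.assert_allclose(cpu, gpu, rtol=1e-6)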
@@ -2,9 +2,11 @@ import xgboost as xgb
 import pytest
 import sys
 import numpy as np
+import unittest

 sys.path.append("tests/python")
-import testing as tm
+import testing as tm              # noqa
+import test_with_sklearn as twskl  # noqa

 pytestmark = pytest.mark.skipif(**tm.no_sklearn())
@@ -29,3 +31,10 @@ def test_gpu_binary_classification():
     err = sum(1 for i in range(len(preds))
               if int(preds[i] > 0.5) != labels[i]) / float(len(preds))
     assert err < 0.1
+
+
+class TestGPUBoostFromPrediction(unittest.TestCase):
+    cpu_test = twskl.TestBoostFromPrediction()
+
+    def test_boost_from_prediction_gpu_hist(self):
+        self.cpu_test.run_boost_from_prediction('gpu_hist')
@@ -5,6 +5,7 @@ import tempfile
 import os
 import shutil
 import pytest
+import unittest

 rng = np.random.RandomState(1994)
@@ -697,21 +698,37 @@ def test_XGBClassifier_resume():
     assert log_loss1 > log_loss2


-def test_boost_from_prediction():
-    from sklearn.datasets import load_breast_cancer
-    X, y = load_breast_cancer(return_X_y=True)
-    model_0 = xgb.XGBClassifier(
-        learning_rate=0.3, random_state=0, n_estimators=4)
-    model_0.fit(X=X, y=y)
-    margin = model_0.predict(X, output_margin=True)
+class TestBoostFromPrediction(unittest.TestCase):
+    def run_boost_from_prediction(self, tree_method):
+        from sklearn.datasets import load_breast_cancer
+        X, y = load_breast_cancer(return_X_y=True)
+        model_0 = xgb.XGBClassifier(
+            learning_rate=0.3, random_state=0, n_estimators=4,
+            tree_method=tree_method)
+        model_0.fit(X=X, y=y)
+        margin = model_0.predict(X, output_margin=True)

-    model_1 = xgb.XGBClassifier(
-        learning_rate=0.3, random_state=0, n_estimators=4)
-    model_1.fit(X=X, y=y, base_margin=margin)
-    predictions_1 = model_1.predict(X, base_margin=margin)
+        model_1 = xgb.XGBClassifier(
+            learning_rate=0.3, random_state=0, n_estimators=4,
+            tree_method=tree_method)
+        model_1.fit(X=X, y=y, base_margin=margin)
+        predictions_1 = model_1.predict(X, base_margin=margin)

-    cls_2 = xgb.XGBClassifier(
-        learning_rate=0.3, random_state=0, n_estimators=8)
-    cls_2.fit(X=X, y=y)
-    predictions_2 = cls_2.predict(X, base_margin=margin)
-    assert np.all(predictions_1 == predictions_2)
+        cls_2 = xgb.XGBClassifier(
+            learning_rate=0.3, random_state=0, n_estimators=8,
+            tree_method=tree_method)
+        cls_2.fit(X=X, y=y)
+        predictions_2 = cls_2.predict(X)
+        assert np.all(predictions_1 == predictions_2)
+
+    @pytest.mark.skipif(**tm.no_sklearn())
+    def test_boost_from_prediction_hist(self):
+        self.run_boost_from_prediction('hist')
+
+    @pytest.mark.skipif(**tm.no_sklearn())
+    def test_boost_from_prediction_approx(self):
+        self.run_boost_from_prediction('approx')
+
+    @pytest.mark.skipif(**tm.no_sklearn())
+    def test_boost_from_prediction_exact(self):
+        self.run_boost_from_prediction('exact')
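The refactored test exercises "boosting from prediction": continuing training from a previous model's raw margin. The same idea with the native API, as a hedged sketch (synthetic data; ``set_base_margin`` is the DMatrix-level equivalent of the ``base_margin`` keyword used in the test):

    import numpy as np
    import xgboost as xgb

    X = np.random.randn(200, 5)
    y = np.random.randint(2, size=200)
    dtrain = xgb.DMatrix(X, label=y)

    params = {"objective": "binary:logistic"}
    bst_0 = xgb.train(params, dtrain, num_boost_round=4)

    # Feed bst_0's margin back in as the starting point for the next booster;
    # the four new trees then boost on top of the first four.
    dtrain.set_base_margin(bst_0.predict(dtrain, output_margin=True))
    bst_1 = xgb.train(params, dtrain, num_boost_round=4)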