Make QuantileDMatrix the default for sklearn estimators. (#8220)
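With `tree_method` set to `hist` or `gpu_hist`, the single-node sklearn estimators now construct a `QuantileDMatrix` instead of a regular `DMatrix` to save memory, and the Dask estimators dispatch to `DaskQuantileDMatrix`. Evaluation matrices reuse the training matrix's quantile cuts through the new `ref` argument, the `IterativeDMatrix` constructor moves out of the header, and the hard-coded `sparse_threshold` default is shared via `TrainParam::DftSparseThreshold()`.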

Jiaming Yuan 2022-09-13 13:52:19 +08:00 committed by GitHub
parent a2686543a9
commit bdf265076d
6 changed files with 91 additions and 39 deletions

python-package/xgboost/dask.py

@@ -726,10 +726,9 @@ def _create_quantile_dmatrix(
     if parts is None:
         msg = f"worker {worker.address} has an empty DMatrix."
         LOGGER.warning(msg)
-        import cupy
         d = QuantileDMatrix(
-            cupy.zeros((0, 0)),
+            numpy.empty((0, 0)),
             feature_names=feature_names,
             feature_types=feature_types,
             max_bin=max_bin,
@@ -1544,15 +1543,21 @@ def inplace_predict(  # pylint: disable=unused-argument
 async def _async_wrap_evaluation_matrices(
-    client: Optional["distributed.Client"], **kwargs: Any
+    client: Optional["distributed.Client"],
+    tree_method: Optional[str],
+    max_bin: Optional[int],
+    **kwargs: Any,
 ) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]:
     """A switch function for async environment."""
 
-    def _inner(**kwargs: Any) -> DaskDMatrix:
-        m = DaskDMatrix(client=client, **kwargs)
-        return m
+    def _dispatch(ref: Optional[DaskDMatrix], **kwargs: Any) -> DaskDMatrix:
+        if tree_method in ("hist", "gpu_hist"):
+            return DaskQuantileDMatrix(
+                client=client, ref=ref, max_bin=max_bin, **kwargs
+            )
+        return DaskDMatrix(client=client, **kwargs)
 
-    train_dmatrix, evals = _wrap_evaluation_matrices(create_dmatrix=_inner, **kwargs)
+    train_dmatrix, evals = _wrap_evaluation_matrices(create_dmatrix=_dispatch, **kwargs)
     train_dmatrix = await train_dmatrix
     if evals is None:
         return train_dmatrix, evals
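
With this dispatch, a Dask estimator training with `hist` or `gpu_hist` builds both its training and evaluation data as `DaskQuantileDMatrix`, with each evaluation matrix referencing the training matrix through `ref`; other tree methods keep the plain `DaskDMatrix` path.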
@@ -1756,6 +1761,8 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
         params = self.get_xgb_params()
         dtrain, evals = await _async_wrap_evaluation_matrices(
             client=self.client,
+            tree_method=self.tree_method,
+            max_bin=self.max_bin,
             X=X,
             y=y,
             group=None,
@@ -1851,6 +1858,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
         params = self.get_xgb_params()
         dtrain, evals = await _async_wrap_evaluation_matrices(
             self.client,
+            tree_method=self.tree_method,
+            max_bin=self.max_bin,
             X=X,
             y=y,
             group=None,
@@ -2057,6 +2066,8 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
         params = self.get_xgb_params()
         dtrain, evals = await _async_wrap_evaluation_matrices(
             self.client,
+            tree_method=self.tree_method,
+            max_bin=self.max_bin,
             X=X,
             y=y,
             group=None,
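
For reference, a minimal sketch of the user-visible effect on the Dask estimators; the local cluster and random data below are illustrative assumptions, not part of this diff:

    from dask import array as da
    from dask.distributed import Client, LocalCluster
    from xgboost import dask as dxgb

    if __name__ == "__main__":
        with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
            X = da.random.random((1000, 8), chunks=(250, 8))
            y = da.random.random(1000, chunks=250)
            reg = dxgb.DaskXGBRegressor(tree_method="hist", max_bin=128)
            reg.client = client
            # fit() now wraps the data in DaskQuantileDMatrix rather than DaskDMatrix.
            reg.fit(X, y)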

python-package/xgboost/sklearn.py

@@ -38,6 +38,7 @@ from .core import (
     Booster,
     DMatrix,
     Metric,
+    QuantileDMatrix,
     XGBoostError,
     _convert_ntree_limit,
     _deprecate_positional_args,
@@ -430,7 +431,8 @@ def _wrap_evaluation_matrices(
     enable_categorical: bool,
     feature_types: Optional[FeatureTypes],
 ) -> Tuple[Any, List[Tuple[Any, str]]]:
-    """Convert array_like evaluation matrices into DMatrix. Perform validation on the way."""
+    """Convert array_like evaluation matrices into DMatrix. Perform validation on the
+    way."""
     train_dmatrix = create_dmatrix(
         data=X,
         label=y,
@@ -442,6 +444,7 @@ def _wrap_evaluation_matrices(
         missing=missing,
         enable_categorical=enable_categorical,
         feature_types=feature_types,
+        ref=None,
     )
     n_validation = 0 if eval_set is None else len(eval_set)
@@ -491,6 +494,7 @@ def _wrap_evaluation_matrices(
                 missing=missing,
                 enable_categorical=enable_categorical,
                 feature_types=feature_types,
+                ref=train_dmatrix,
             )
             evals.append(m)
     nevals = len(evals)
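
The `ref` threading above is what makes evaluation `QuantileDMatrix` objects valid: a quantile matrix used for evaluation must reuse the histogram cuts of the training matrix. A standalone sketch of the equivalent explicit construction, with illustrative data:

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X_train, y_train = rng.random((512, 8)), rng.random(512)
    X_valid, y_valid = rng.random((128, 8)), rng.random(128)

    Xy_train = xgb.QuantileDMatrix(X_train, y_train, max_bin=256)
    # ref= makes the evaluation matrix share the training matrix's quantile cuts.
    Xy_valid = xgb.QuantileDMatrix(X_valid, y_valid, ref=Xy_train)
    xgb.train({"tree_method": "hist"}, Xy_train, evals=[(Xy_valid, "valid")])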
@@ -904,6 +908,17 @@ class XGBModel(XGBModelBase):
         return model, metric, params, early_stopping_rounds, callbacks
 
+    def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix:
+        # Use `QuantileDMatrix` to save memory.
+        if self.tree_method in ("hist", "gpu_hist"):
+            try:
+                return QuantileDMatrix(
+                    **kwargs, ref=ref, nthread=self.n_jobs, max_bin=self.max_bin
+                )
+            except TypeError:  # `QuantileDMatrix` supports lesser types than DMatrix
+                pass
+        return DMatrix(**kwargs, nthread=self.n_jobs)
+
     def _set_evaluation_result(self, evals_result: TrainingCallback.EvalsLog) -> None:
         if evals_result:
             self.evals_result_ = cast(Dict[str, Dict[str, List[float]]], evals_result)
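
A short sketch of what `_create_dmatrix` means for single-node users (illustrative data; the comment spells out the internal call it roughly resolves to). Note the `TypeError` fallback: input types that `QuantileDMatrix` does not support silently fall back to `DMatrix`.

    import numpy as np
    from xgboost import XGBRegressor

    X, y = np.random.rand(1000, 10), np.random.rand(1000)
    reg = XGBRegressor(tree_method="hist", max_bin=128, n_jobs=2)
    # Internally this now resolves to roughly:
    #   QuantileDMatrix(data=X, label=y, ref=None, nthread=2, max_bin=128)
    reg.fit(X, y)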
@@ -996,7 +1011,7 @@ class XGBModel(XGBModelBase):
             base_margin_eval_set=base_margin_eval_set,
             eval_group=None,
             eval_qid=None,
-            create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
+            create_dmatrix=self._create_dmatrix,
             enable_categorical=self.enable_categorical,
             feature_types=self.feature_types,
         )
@@ -1479,7 +1494,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
             base_margin_eval_set=base_margin_eval_set,
             eval_group=None,
             eval_qid=None,
-            create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
+            create_dmatrix=self._create_dmatrix,
             enable_categorical=self.enable_categorical,
             feature_types=self.feature_types,
         )
@@ -1930,7 +1945,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
             base_margin_eval_set=base_margin_eval_set,
             eval_group=eval_group,
             eval_qid=eval_qid,
-            create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
+            create_dmatrix=self._create_dmatrix,
             enable_categorical=self.enable_categorical,
             feature_types=self.feature_types,
         )

src/data/iterative_dmatrix.cc

@@ -7,6 +7,7 @@
 #include "../common/column_matrix.h"
 #include "../common/hist_util.h"
+#include "../tree/param.h"  // FIXME(jiamingy): Find a better way to share this parameter.
 #include "gradient_index.h"
 #include "proxy_dmatrix.h"
 #include "simple_batch_iterator.h"
@@ -14,6 +15,38 @@
 namespace xgboost {
 namespace data {
 
+IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
+                                   std::shared_ptr<DMatrix> ref, DataIterResetCallback* reset,
+                                   XGDMatrixCallbackNext* next, float missing, int nthread,
+                                   bst_bin_t max_bin)
+    : proxy_{proxy}, reset_{reset}, next_{next} {
+  // fetch the first batch
+  auto iter =
+      DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{iter_handle, reset_, next_};
+  iter.Reset();
+  bool valid = iter.Next();
+  CHECK(valid) << "Iterative DMatrix must have at least 1 batch.";
+
+  auto d = MakeProxy(proxy_)->DeviceIdx();
+
+  StringView msg{"All batch should be on the same device."};
+  if (batch_param_.gpu_id != Context::kCpuId) {
+    CHECK_EQ(d, batch_param_.gpu_id) << msg;
+  }
+
+  batch_param_ = BatchParam{d, max_bin};
+  // hardcoded parameter.
+  batch_param_.sparse_thresh = tree::TrainParam::DftSparseThreshold();
+
+  ctx_.UpdateAllowUnknown(
+      Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(d)}});
+  if (ctx_.IsCPU()) {
+    this->InitFromCPU(iter_handle, missing, ref);
+  } else {
+    this->InitFromCUDA(iter_handle, missing, ref);
+  }
+}
+
 void GetCutsFromRef(std::shared_ptr<DMatrix> ref_, bst_feature_t n_features, BatchParam p,
                     common::HistogramCuts* p_cuts) {
   CHECK(ref_);
@@ -199,6 +232,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
   if (n_batches == 1) {
     this->info_ = std::move(proxy->Info());
     this->info_.num_nonzero_ = nnz;
+    this->info_.num_col_ = n_features;  // proxy might be empty.
     CHECK_EQ(proxy->Info().labels.Size(), 0);
   }
 }
@@ -210,6 +244,10 @@ BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const& param)
     ghist_ = std::make_shared<GHistIndexMatrix>(&ctx_, Info(), *ellpack_, param);
   }
+  if (param.sparse_thresh != tree::TrainParam::DftSparseThreshold()) {
+    LOG(WARNING) << "`sparse_threshold` can not be changed when `QuantileDMatrix` is used instead "
+                    "of `DMatrix`.";
+  }
   auto begin_iter =
       BatchIterator<GHistIndexMatrix>(new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_));
   return BatchSet<GHistIndexMatrix>(begin_iter);
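
The warning above fires when a non-default `sparse_threshold` reaches a `QuantileDMatrix`, whose gradient index is already built. A sketch of how a user would encounter it, assuming `sparse_threshold` is passed as an ordinary training parameter:

    import numpy as np
    import xgboost as xgb

    X, y = np.random.rand(256, 4), np.random.rand(256)
    Xy = xgb.QuantileDMatrix(X, y)
    # The index inside Xy is fixed, so the custom threshold only logs the warning.
    xgb.train({"tree_method": "hist", "sparse_threshold": 0.5}, Xy, num_boost_round=2)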

src/data/iterative_dmatrix.cu

@@ -173,8 +173,15 @@ BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(BatchParam const& param)
   }
   if (!ellpack_ && ghist_) {
     ellpack_.reset(new EllpackPage());
+    // Evaluation QuantileDMatrix initialized from CPU data might not have the correct GPU
+    // ID.
     if (this->ctx_.IsCPU()) {
       this->ctx_.gpu_id = param.gpu_id;
       this->Info().feature_types.SetDevice(param.gpu_id);
     }
+    if (this->ctx_.IsCPU()) {
+      this->ctx_.gpu_id = dh::CurrentDevice();
+    }
+    this->Info().feature_types.SetDevice(this->ctx_.gpu_id);
     *ellpack_->Impl() =
         EllpackPageImpl(&ctx_, *this->ghist_, this->Info().feature_types.ConstDeviceSpan());
   }
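
The checks above pick a device for an evaluation `QuantileDMatrix` that was built from CPU data but is consumed by a GPU batch request: first adopt the `gpu_id` carried by the batch parameter, and if the context still reports CPU, fall back to the current CUDA device before moving the feature types onto it.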

src/data/iterative_dmatrix.h

@@ -75,30 +75,7 @@ class IterativeDMatrix : public DMatrix {
   explicit IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
                             std::shared_ptr<DMatrix> ref, DataIterResetCallback *reset,
                             XGDMatrixCallbackNext *next, float missing, int nthread,
-                            bst_bin_t max_bin)
-      : proxy_{proxy}, reset_{reset}, next_{next} {
-    // fetch the first batch
-    auto iter =
-        DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{iter_handle, reset_, next_};
-    iter.Reset();
-    bool valid = iter.Next();
-    CHECK(valid) << "Iterative DMatrix must have at least 1 batch.";
-    auto d = MakeProxy(proxy_)->DeviceIdx();
-    if (batch_param_.gpu_id != Context::kCpuId) {
-      CHECK_EQ(d, batch_param_.gpu_id) << "All batch should be on the same device.";
-    }
-    batch_param_ = BatchParam{d, max_bin};
-    batch_param_.sparse_thresh = 0.2;  // default from TrainParam
-    ctx_.UpdateAllowUnknown(
-        Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(d)}});
-    if (ctx_.IsCPU()) {
-      this->InitFromCPU(iter_handle, missing, ref);
-    } else {
-      this->InitFromCUDA(iter_handle, missing, ref);
-    }
-  }
+                            bst_bin_t max_bin);
 
   ~IterativeDMatrix() override = default;
   bool EllpackExists() const override { return static_cast<bool>(ellpack_); }

src/tree/param.h

@@ -78,7 +78,9 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
   // ------ From CPU quantile histogram -------.
   // percentage threshold for treating a feature as sparse
   // e.g. 0.2 indicates a feature with fewer than 20% nonzeros is considered sparse
-  double sparse_threshold;
+  static constexpr double DftSparseThreshold() { return 0.2; }
+  double sparse_threshold{DftSparseThreshold()};
 
   // declare the parameters
   DMLC_DECLARE_PARAMETER(TrainParam) {
@@ -182,7 +184,9 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
                   "See tutorial for more information");
     // ------ From cpu quantile histogram -------.
-    DMLC_DECLARE_FIELD(sparse_threshold).set_range(0, 1.0).set_default(0.2)
+    DMLC_DECLARE_FIELD(sparse_threshold)
+        .set_range(0, 1.0)
+        .set_default(DftSparseThreshold())
         .describe("percentage threshold for treating a feature as sparse");
     // add alias of parameters
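
Extracting the default into `DftSparseThreshold()` lets the `IterativeDMatrix` constructor and the `GetGradientIndex` warning check compare against the same constant instead of repeating the literal `0.2`.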