Make QuantileDMatrix default to sklearn esitmators. (#8220)
This commit is contained in:
parent
a2686543a9
commit
bdf265076d
@ -726,10 +726,9 @@ def _create_quantile_dmatrix(
|
||||
if parts is None:
|
||||
msg = f"worker {worker.address} has an empty DMatrix."
|
||||
LOGGER.warning(msg)
|
||||
import cupy
|
||||
|
||||
d = QuantileDMatrix(
|
||||
cupy.zeros((0, 0)),
|
||||
numpy.empty((0, 0)),
|
||||
feature_names=feature_names,
|
||||
feature_types=feature_types,
|
||||
max_bin=max_bin,
|
||||
@ -1544,15 +1543,21 @@ def inplace_predict( # pylint: disable=unused-argument
|
||||
|
||||
|
||||
async def _async_wrap_evaluation_matrices(
|
||||
client: Optional["distributed.Client"], **kwargs: Any
|
||||
client: Optional["distributed.Client"],
|
||||
tree_method: Optional[str],
|
||||
max_bin: Optional[int],
|
||||
**kwargs: Any,
|
||||
) -> Tuple[DaskDMatrix, Optional[List[Tuple[DaskDMatrix, str]]]]:
|
||||
"""A switch function for async environment."""
|
||||
|
||||
def _inner(**kwargs: Any) -> DaskDMatrix:
|
||||
m = DaskDMatrix(client=client, **kwargs)
|
||||
return m
|
||||
def _dispatch(ref: Optional[DaskDMatrix], **kwargs: Any) -> DaskDMatrix:
|
||||
if tree_method in ("hist", "gpu_hist"):
|
||||
return DaskQuantileDMatrix(
|
||||
client=client, ref=ref, max_bin=max_bin, **kwargs
|
||||
)
|
||||
return DaskDMatrix(client=client, **kwargs)
|
||||
|
||||
train_dmatrix, evals = _wrap_evaluation_matrices(create_dmatrix=_inner, **kwargs)
|
||||
train_dmatrix, evals = _wrap_evaluation_matrices(create_dmatrix=_dispatch, **kwargs)
|
||||
train_dmatrix = await train_dmatrix
|
||||
if evals is None:
|
||||
return train_dmatrix, evals
|
||||
@ -1756,6 +1761,8 @@ class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
|
||||
params = self.get_xgb_params()
|
||||
dtrain, evals = await _async_wrap_evaluation_matrices(
|
||||
client=self.client,
|
||||
tree_method=self.tree_method,
|
||||
max_bin=self.max_bin,
|
||||
X=X,
|
||||
y=y,
|
||||
group=None,
|
||||
@ -1851,6 +1858,8 @@ class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
|
||||
params = self.get_xgb_params()
|
||||
dtrain, evals = await _async_wrap_evaluation_matrices(
|
||||
self.client,
|
||||
tree_method=self.tree_method,
|
||||
max_bin=self.max_bin,
|
||||
X=X,
|
||||
y=y,
|
||||
group=None,
|
||||
@ -2057,6 +2066,8 @@ class DaskXGBRanker(DaskScikitLearnBase, XGBRankerMixIn):
|
||||
params = self.get_xgb_params()
|
||||
dtrain, evals = await _async_wrap_evaluation_matrices(
|
||||
self.client,
|
||||
tree_method=self.tree_method,
|
||||
max_bin=self.max_bin,
|
||||
X=X,
|
||||
y=y,
|
||||
group=None,
|
||||
|
||||
@ -38,6 +38,7 @@ from .core import (
|
||||
Booster,
|
||||
DMatrix,
|
||||
Metric,
|
||||
QuantileDMatrix,
|
||||
XGBoostError,
|
||||
_convert_ntree_limit,
|
||||
_deprecate_positional_args,
|
||||
@ -430,7 +431,8 @@ def _wrap_evaluation_matrices(
|
||||
enable_categorical: bool,
|
||||
feature_types: Optional[FeatureTypes],
|
||||
) -> Tuple[Any, List[Tuple[Any, str]]]:
|
||||
"""Convert array_like evaluation matrices into DMatrix. Perform validation on the way."""
|
||||
"""Convert array_like evaluation matrices into DMatrix. Perform validation on the
|
||||
way."""
|
||||
train_dmatrix = create_dmatrix(
|
||||
data=X,
|
||||
label=y,
|
||||
@ -442,6 +444,7 @@ def _wrap_evaluation_matrices(
|
||||
missing=missing,
|
||||
enable_categorical=enable_categorical,
|
||||
feature_types=feature_types,
|
||||
ref=None,
|
||||
)
|
||||
|
||||
n_validation = 0 if eval_set is None else len(eval_set)
|
||||
@ -491,6 +494,7 @@ def _wrap_evaluation_matrices(
|
||||
missing=missing,
|
||||
enable_categorical=enable_categorical,
|
||||
feature_types=feature_types,
|
||||
ref=train_dmatrix,
|
||||
)
|
||||
evals.append(m)
|
||||
nevals = len(evals)
|
||||
@ -904,6 +908,17 @@ class XGBModel(XGBModelBase):
|
||||
|
||||
return model, metric, params, early_stopping_rounds, callbacks
|
||||
|
||||
def _create_dmatrix(self, ref: Optional[DMatrix], **kwargs: Any) -> DMatrix:
|
||||
# Use `QuantileDMatrix` to save memory.
|
||||
if self.tree_method in ("hist", "gpu_hist"):
|
||||
try:
|
||||
return QuantileDMatrix(
|
||||
**kwargs, ref=ref, nthread=self.n_jobs, max_bin=self.max_bin
|
||||
)
|
||||
except TypeError: # `QuantileDMatrix` supports lesser types than DMatrix
|
||||
pass
|
||||
return DMatrix(**kwargs, nthread=self.n_jobs)
|
||||
|
||||
def _set_evaluation_result(self, evals_result: TrainingCallback.EvalsLog) -> None:
|
||||
if evals_result:
|
||||
self.evals_result_ = cast(Dict[str, Dict[str, List[float]]], evals_result)
|
||||
@ -996,7 +1011,7 @@ class XGBModel(XGBModelBase):
|
||||
base_margin_eval_set=base_margin_eval_set,
|
||||
eval_group=None,
|
||||
eval_qid=None,
|
||||
create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
|
||||
create_dmatrix=self._create_dmatrix,
|
||||
enable_categorical=self.enable_categorical,
|
||||
feature_types=self.feature_types,
|
||||
)
|
||||
@ -1479,7 +1494,7 @@ class XGBClassifier(XGBModel, XGBClassifierBase):
|
||||
base_margin_eval_set=base_margin_eval_set,
|
||||
eval_group=None,
|
||||
eval_qid=None,
|
||||
create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
|
||||
create_dmatrix=self._create_dmatrix,
|
||||
enable_categorical=self.enable_categorical,
|
||||
feature_types=self.feature_types,
|
||||
)
|
||||
@ -1930,7 +1945,7 @@ class XGBRanker(XGBModel, XGBRankerMixIn):
|
||||
base_margin_eval_set=base_margin_eval_set,
|
||||
eval_group=eval_group,
|
||||
eval_qid=eval_qid,
|
||||
create_dmatrix=lambda **kwargs: DMatrix(nthread=self.n_jobs, **kwargs),
|
||||
create_dmatrix=self._create_dmatrix,
|
||||
enable_categorical=self.enable_categorical,
|
||||
feature_types=self.feature_types,
|
||||
)
|
||||
|
||||
@ -7,6 +7,7 @@
|
||||
|
||||
#include "../common/column_matrix.h"
|
||||
#include "../common/hist_util.h"
|
||||
#include "../tree/param.h" // FIXME(jiamingy): Find a better way to share this parameter.
|
||||
#include "gradient_index.h"
|
||||
#include "proxy_dmatrix.h"
|
||||
#include "simple_batch_iterator.h"
|
||||
@ -14,6 +15,38 @@
|
||||
namespace xgboost {
|
||||
namespace data {
|
||||
|
||||
IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
|
||||
std::shared_ptr<DMatrix> ref, DataIterResetCallback* reset,
|
||||
XGDMatrixCallbackNext* next, float missing, int nthread,
|
||||
bst_bin_t max_bin)
|
||||
: proxy_{proxy}, reset_{reset}, next_{next} {
|
||||
// fetch the first batch
|
||||
auto iter =
|
||||
DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{iter_handle, reset_, next_};
|
||||
iter.Reset();
|
||||
bool valid = iter.Next();
|
||||
CHECK(valid) << "Iterative DMatrix must have at least 1 batch.";
|
||||
|
||||
auto d = MakeProxy(proxy_)->DeviceIdx();
|
||||
|
||||
StringView msg{"All batch should be on the same device."};
|
||||
if (batch_param_.gpu_id != Context::kCpuId) {
|
||||
CHECK_EQ(d, batch_param_.gpu_id) << msg;
|
||||
}
|
||||
|
||||
batch_param_ = BatchParam{d, max_bin};
|
||||
// hardcoded parameter.
|
||||
batch_param_.sparse_thresh = tree::TrainParam::DftSparseThreshold();
|
||||
|
||||
ctx_.UpdateAllowUnknown(
|
||||
Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(d)}});
|
||||
if (ctx_.IsCPU()) {
|
||||
this->InitFromCPU(iter_handle, missing, ref);
|
||||
} else {
|
||||
this->InitFromCUDA(iter_handle, missing, ref);
|
||||
}
|
||||
}
|
||||
|
||||
void GetCutsFromRef(std::shared_ptr<DMatrix> ref_, bst_feature_t n_features, BatchParam p,
|
||||
common::HistogramCuts* p_cuts) {
|
||||
CHECK(ref_);
|
||||
@ -199,6 +232,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
|
||||
if (n_batches == 1) {
|
||||
this->info_ = std::move(proxy->Info());
|
||||
this->info_.num_nonzero_ = nnz;
|
||||
this->info_.num_col_ = n_features; // proxy might be empty.
|
||||
CHECK_EQ(proxy->Info().labels.Size(), 0);
|
||||
}
|
||||
}
|
||||
@ -210,6 +244,10 @@ BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const&
|
||||
ghist_ = std::make_shared<GHistIndexMatrix>(&ctx_, Info(), *ellpack_, param);
|
||||
}
|
||||
|
||||
if (param.sparse_thresh != tree::TrainParam::DftSparseThreshold()) {
|
||||
LOG(WARNING) << "`sparse_threshold` can not be changed when `QuantileDMatrix` is used instead "
|
||||
"of `DMatrix`.";
|
||||
}
|
||||
auto begin_iter =
|
||||
BatchIterator<GHistIndexMatrix>(new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_));
|
||||
return BatchSet<GHistIndexMatrix>(begin_iter);
|
||||
|
||||
@ -173,8 +173,15 @@ BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(BatchParam const& para
|
||||
}
|
||||
if (!ellpack_ && ghist_) {
|
||||
ellpack_.reset(new EllpackPage());
|
||||
this->ctx_.gpu_id = param.gpu_id;
|
||||
this->Info().feature_types.SetDevice(param.gpu_id);
|
||||
// Evaluation QuantileDMatrix initialized from CPU data might not have the correct GPU
|
||||
// ID.
|
||||
if (this->ctx_.IsCPU()) {
|
||||
this->ctx_.gpu_id = param.gpu_id;
|
||||
}
|
||||
if (this->ctx_.IsCPU()) {
|
||||
this->ctx_.gpu_id = dh::CurrentDevice();
|
||||
}
|
||||
this->Info().feature_types.SetDevice(this->ctx_.gpu_id);
|
||||
*ellpack_->Impl() =
|
||||
EllpackPageImpl(&ctx_, *this->ghist_, this->Info().feature_types.ConstDeviceSpan());
|
||||
}
|
||||
|
||||
@ -75,30 +75,7 @@ class IterativeDMatrix : public DMatrix {
|
||||
explicit IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
|
||||
std::shared_ptr<DMatrix> ref, DataIterResetCallback *reset,
|
||||
XGDMatrixCallbackNext *next, float missing, int nthread,
|
||||
bst_bin_t max_bin)
|
||||
: proxy_{proxy}, reset_{reset}, next_{next} {
|
||||
// fetch the first batch
|
||||
auto iter =
|
||||
DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{iter_handle, reset_, next_};
|
||||
iter.Reset();
|
||||
bool valid = iter.Next();
|
||||
CHECK(valid) << "Iterative DMatrix must have at least 1 batch.";
|
||||
|
||||
auto d = MakeProxy(proxy_)->DeviceIdx();
|
||||
if (batch_param_.gpu_id != Context::kCpuId) {
|
||||
CHECK_EQ(d, batch_param_.gpu_id) << "All batch should be on the same device.";
|
||||
}
|
||||
batch_param_ = BatchParam{d, max_bin};
|
||||
batch_param_.sparse_thresh = 0.2; // default from TrainParam
|
||||
|
||||
ctx_.UpdateAllowUnknown(
|
||||
Args{{"nthread", std::to_string(nthread)}, {"gpu_id", std::to_string(d)}});
|
||||
if (ctx_.IsCPU()) {
|
||||
this->InitFromCPU(iter_handle, missing, ref);
|
||||
} else {
|
||||
this->InitFromCUDA(iter_handle, missing, ref);
|
||||
}
|
||||
}
|
||||
bst_bin_t max_bin);
|
||||
~IterativeDMatrix() override = default;
|
||||
|
||||
bool EllpackExists() const override { return static_cast<bool>(ellpack_); }
|
||||
|
||||
@ -78,7 +78,9 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
|
||||
// ------ From CPU quantile histogram -------.
|
||||
// percentage threshold for treating a feature as sparse
|
||||
// e.g. 0.2 indicates a feature with fewer than 20% nonzeros is considered sparse
|
||||
double sparse_threshold;
|
||||
static constexpr double DftSparseThreshold() { return 0.2; }
|
||||
|
||||
double sparse_threshold{DftSparseThreshold()};
|
||||
|
||||
// declare the parameters
|
||||
DMLC_DECLARE_PARAMETER(TrainParam) {
|
||||
@ -182,7 +184,9 @@ struct TrainParam : public XGBoostParameter<TrainParam> {
|
||||
"See tutorial for more information");
|
||||
|
||||
// ------ From cpu quantile histogram -------.
|
||||
DMLC_DECLARE_FIELD(sparse_threshold).set_range(0, 1.0).set_default(0.2)
|
||||
DMLC_DECLARE_FIELD(sparse_threshold)
|
||||
.set_range(0, 1.0)
|
||||
.set_default(DftSparseThreshold())
|
||||
.describe("percentage threshold for treating a feature as sparse");
|
||||
|
||||
// add alias of parameters
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user