Use Predictor for dart. (#6693)
* Use normal predictor for dart booster. * Implement `inplace_predict` for dart. * Enable `dart` for dask interface now that it's thread-safe. * Categorical data should now work out of the box for dart. The implementation is not very efficient, as it has to pull back the data and apply the weight for each tree, but it is still a significant improvement over the previous implementation because we no longer binary search for each sample. * Fix output prediction shape on dataframe.
This commit is contained in:
parent
dbf7e9d3cb
commit
e8c5c53e2f
@ -119,6 +119,17 @@ class Predictor {
|
|||||||
*/
|
*/
|
||||||
virtual void Configure(const std::vector<std::pair<std::string, std::string>>&);
|
virtual void Configure(const std::vector<std::pair<std::string, std::string>>&);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* \brief Initialize output prediction
|
||||||
|
*
|
||||||
|
* \param info Meta info for the DMatrix object used for prediction.
|
||||||
|
* \param out_predt Prediction vector to be initialized.
|
||||||
|
* \param model Tree model used for prediction.
|
||||||
|
*/
|
||||||
|
virtual void InitOutPredictions(const MetaInfo &info,
|
||||||
|
HostDeviceVector<bst_float> *out_predt,
|
||||||
|
const gbm::GBTreeModel &model) const = 0;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* \brief Generate batch predictions for a given feature matrix. May use
|
* \brief Generate batch predictions for a given feature matrix. May use
|
||||||
* cached predictions if available instead of calculating from scratch.
|
* cached predictions if available instead of calculating from scratch.
|
||||||
|
|||||||
@ -804,7 +804,7 @@ async def _train_async(
|
|||||||
workers = list(_get_workers_from_data(dtrain, evals))
|
workers = list(_get_workers_from_data(dtrain, evals))
|
||||||
_rabit_args = await _get_rabit_args(len(workers), client)
|
_rabit_args = await _get_rabit_args(len(workers), client)
|
||||||
|
|
||||||
if params.get("booster", None) is not None and params["booster"] != "gbtree":
|
if params.get("booster", None) == "gblinear":
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"booster `{params['booster']}` is not yet supported for dask."
|
f"booster `{params['booster']}` is not yet supported for dask."
|
||||||
)
|
)
|
||||||
@ -949,6 +949,15 @@ async def _direct_predict_impl(
|
|||||||
meta: Dict[int, str],
|
meta: Dict[int, str],
|
||||||
) -> _DaskCollection:
|
) -> _DaskCollection:
|
||||||
columns = list(meta.keys())
|
columns = list(meta.keys())
|
||||||
|
if len(output_shape) >= 3 and isinstance(data, dd.DataFrame):
|
||||||
|
# Without this check, dask will finish the prediction silently even if output
|
||||||
|
# dimension is greater than 2. But during map_partitions, dask passes a
|
||||||
|
# `dd.DataFrame` as local input to xgboost, which is converted to csr_matrix by
|
||||||
|
# `_convert_unknown_data` since dd.DataFrame is not known to xgboost native
|
||||||
|
# binding.
|
||||||
|
raise ValueError(
|
||||||
|
"Use `da.Array` or `DaskDMatrix` when output has more than 2 dimensions."
|
||||||
|
)
|
||||||
if _can_output_df(isinstance(data, dd.DataFrame), output_shape):
|
if _can_output_df(isinstance(data, dd.DataFrame), output_shape):
|
||||||
if base_margin is not None and isinstance(base_margin, da.Array):
|
if base_margin is not None and isinstance(base_margin, da.Array):
|
||||||
# Easier for map_partitions
|
# Easier for map_partitions
|
||||||
@ -1012,6 +1021,7 @@ def _infer_predict_output(
|
|||||||
if kwargs.pop("predict_type") == "margin":
|
if kwargs.pop("predict_type") == "margin":
|
||||||
kwargs["output_margin"] = True
|
kwargs["output_margin"] = True
|
||||||
m = DMatrix(test_sample)
|
m = DMatrix(test_sample)
|
||||||
|
# generated DMatrix doesn't have feature name, so no validation.
|
||||||
test_predt = booster.predict(m, validate_features=False, **kwargs)
|
test_predt = booster.predict(m, validate_features=False, **kwargs)
|
||||||
n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1
|
n_columns = test_predt.shape[1] if len(test_predt.shape) > 1 else 1
|
||||||
meta: Dict[int, str] = {}
|
meta: Dict[int, str] = {}
|
||||||
@ -1098,6 +1108,7 @@ async def _predict_async(
|
|||||||
pred_contribs=pred_contribs,
|
pred_contribs=pred_contribs,
|
||||||
approx_contribs=approx_contribs,
|
approx_contribs=approx_contribs,
|
||||||
pred_interactions=pred_interactions,
|
pred_interactions=pred_interactions,
|
||||||
|
strict_shape=strict_shape,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return await _direct_predict_impl(
|
return await _direct_predict_impl(
|
||||||
@ -1116,6 +1127,7 @@ async def _predict_async(
|
|||||||
pred_contribs=pred_contribs,
|
pred_contribs=pred_contribs,
|
||||||
approx_contribs=approx_contribs,
|
approx_contribs=approx_contribs,
|
||||||
pred_interactions=pred_interactions,
|
pred_interactions=pred_interactions,
|
||||||
|
strict_shape=strict_shape,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Prediction on dask DMatrix.
|
# Prediction on dask DMatrix.
|
||||||
@ -1206,10 +1218,9 @@ def predict( # pylint: disable=unused-argument
|
|||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
Using ``inplace_predict`` might be faster when some features are not needed. See
|
Using ``inplace_predict`` might be faster when some features are not needed. See
|
||||||
:py:meth:`xgboost.Booster.predict` for details on various parameters. When using
|
:py:meth:`xgboost.Booster.predict` for details on various parameters. When output
|
||||||
``pred_interactions`` with mutli-class model, input should be ``da.Array`` or
|
has more than 2 dimensions (shap value, leaf with strict_shape), input should be
|
||||||
``DaskDMatrix`` due to limitation in ``da.map_blocks``.
|
``da.Array`` or ``DaskDMatrix``.
|
||||||
|
|
||||||
|
|
||||||
.. versionadded:: 1.0.0
|
.. versionadded:: 1.0.0
|
||||||
|
|
||||||
@ -1233,8 +1244,8 @@ def predict( # pylint: disable=unused-argument
|
|||||||
prediction: dask.array.Array/dask.dataframe.Series
|
prediction: dask.array.Array/dask.dataframe.Series
|
||||||
When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is an
|
When input data is ``dask.array.Array`` or ``DaskDMatrix``, the return value is an
|
||||||
array, when input data is ``dask.dataframe.DataFrame``, return value can be
|
array, when input data is ``dask.dataframe.DataFrame``, return value can be
|
||||||
``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or ``dask.array.Array``,
|
``dask.dataframe.Series`` or ``dask.dataframe.DataFrame``, depending on the output
|
||||||
depending on the output shape.
|
shape.
|
||||||
|
|
||||||
'''
|
'''
|
||||||
_assert_dask_support()
|
_assert_dask_support()
|
||||||
@ -1297,6 +1308,7 @@ async def _inplace_predict_async( # pylint: disable=too-many-branches
|
|||||||
inplace=True,
|
inplace=True,
|
||||||
predict_type=predict_type,
|
predict_type=predict_type,
|
||||||
iteration_range=iteration_range,
|
iteration_range=iteration_range,
|
||||||
|
strict_shape=strict_shape,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return await _direct_predict_impl(
|
return await _direct_predict_impl(
|
||||||
@ -1352,8 +1364,9 @@ def inplace_predict( # pylint: disable=unused-argument
|
|||||||
prediction :
|
prediction :
|
||||||
When input data is ``dask.array.Array``, the return value is an array, when input
|
When input data is ``dask.array.Array``, the return value is an array, when input
|
||||||
data is ``dask.dataframe.DataFrame``, return value can be
|
data is ``dask.dataframe.DataFrame``, return value can be
|
||||||
``dask.dataframe.Series``, ``dask.dataframe.DataFrame`` or ``dask.array.Array``,
|
``dask.dataframe.Series`` or ``dask.dataframe.DataFrame``, depending on the output
|
||||||
depending on the output shape.
|
shape.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
_assert_dask_support()
|
_assert_dask_support()
|
||||||
client = _xgb_get_client(client)
|
client = _xgb_get_client(client)
|
||||||
|
|||||||
@ -754,10 +754,7 @@ class XGBModel(XGBModelBase):
|
|||||||
# Inplace predict doesn't handle as many data types as DMatrix, but it's
|
# Inplace predict doesn't handle as many data types as DMatrix, but it's
|
||||||
# sufficient for dask interface where input is simpler.
|
# sufficient for dask interface where input is simpler.
|
||||||
params = self.get_params()
|
params = self.get_params()
|
||||||
booster = self.booster
|
if params.get("predictor", None) is None and self.booster != "gblinear":
|
||||||
if params.get("predictor", None) is None and (
|
|
||||||
booster is None or booster == "gbtree"
|
|
||||||
):
|
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|||||||
@ -455,12 +455,22 @@ void GBTree::PredictBatch(DMatrix* p_fmat,
|
|||||||
// When begin layer is not 0, the cache is not useful.
|
// When begin layer is not 0, the cache is not useful.
|
||||||
reset = true;
|
reset = true;
|
||||||
}
|
}
|
||||||
|
if (out_preds->predictions.Size() == 0 && p_fmat->Info().num_row_ != 0) {
|
||||||
|
CHECK_EQ(out_preds->version, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto const& predictor = GetPredictor(&out_preds->predictions, p_fmat);
|
||||||
|
if (out_preds->version == 0) {
|
||||||
|
// out_preds->Size() can be non-zero as it's initialized here before any
|
||||||
|
// tree is built at the 0^th iteration.
|
||||||
|
predictor->InitOutPredictions(p_fmat->Info(), &out_preds->predictions,
|
||||||
|
model_);
|
||||||
|
}
|
||||||
|
|
||||||
uint32_t tree_begin, tree_end;
|
uint32_t tree_begin, tree_end;
|
||||||
std::tie(tree_begin, tree_end) =
|
std::tie(tree_begin, tree_end) =
|
||||||
detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
|
detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
|
||||||
GetPredictor(&out_preds->predictions, p_fmat)
|
predictor->PredictBatch(p_fmat, out_preds, model_, tree_begin, tree_end);
|
||||||
->PredictBatch(p_fmat, out_preds, model_, tree_begin, tree_end);
|
|
||||||
if (reset) {
|
if (reset) {
|
||||||
out_preds->version = 0;
|
out_preds->version = 0;
|
||||||
} else {
|
} else {
|
||||||
@ -625,54 +635,124 @@ class Dart : public GBTree {
|
|||||||
out["dart_train_param"] = ToJson(dparam_);
|
out["dart_train_param"] = ToJson(dparam_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// An independent const function to make sure it's thread safe.
|
||||||
|
void PredictBatchImpl(DMatrix *p_fmat, PredictionCacheEntry *p_out_preds,
|
||||||
|
bool training, unsigned layer_begin,
|
||||||
|
unsigned layer_end) const {
|
||||||
|
auto &predictor = this->GetPredictor(&p_out_preds->predictions, p_fmat);
|
||||||
|
CHECK(predictor);
|
||||||
|
predictor->InitOutPredictions(p_fmat->Info(), &p_out_preds->predictions,
|
||||||
|
model_);
|
||||||
|
p_out_preds->version = 0;
|
||||||
|
uint32_t tree_begin, tree_end;
|
||||||
|
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
|
||||||
|
for (size_t i = tree_begin; i < tree_end; i += 1) {
|
||||||
|
if (training &&
|
||||||
|
std::binary_search(idx_drop_.cbegin(), idx_drop_.cend(), i)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
CHECK_GE(i, p_out_preds->version);
|
||||||
|
auto version = i / this->LayerTrees();
|
||||||
|
p_out_preds->version = version;
|
||||||
|
|
||||||
|
auto n_groups = model_.learner_model_param->num_output_group;
|
||||||
|
PredictionCacheEntry predts;
|
||||||
|
predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
|
||||||
|
predictor->PredictBatch(p_fmat, &predts, model_, i, i + 1);
|
||||||
|
|
||||||
|
// Multiply the prediction by the tree weight and add it to the output prediction.
|
||||||
|
auto w = this->weight_drop_.at(i);
|
||||||
|
auto &h_predts = predts.predictions.HostVector();
|
||||||
|
auto group = model_.tree_info.at(i);
|
||||||
|
auto &h_out_predts = p_out_preds->predictions.HostVector();
|
||||||
|
CHECK_EQ(h_out_predts.size(), h_predts.size());
|
||||||
|
for (size_t ridx = 0; ridx < p_fmat->Info().num_row_; ++ridx) {
|
||||||
|
const size_t offset = ridx * n_groups + group;
|
||||||
|
h_out_predts[offset] += (h_predts[offset] * w);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void PredictBatch(DMatrix* p_fmat,
|
void PredictBatch(DMatrix* p_fmat,
|
||||||
PredictionCacheEntry* p_out_preds,
|
PredictionCacheEntry* p_out_preds,
|
||||||
bool training,
|
bool training,
|
||||||
unsigned layer_begin,
|
unsigned layer_begin,
|
||||||
unsigned layer_end) override {
|
unsigned layer_end) override {
|
||||||
DropTrees(training);
|
DropTrees(training);
|
||||||
int num_group = model_.learner_model_param->num_output_group;
|
this->PredictBatchImpl(p_fmat, p_out_preds, training, layer_begin, layer_end);
|
||||||
uint32_t tree_begin, tree_end;
|
}
|
||||||
std::tie(tree_begin, tree_end) =
|
|
||||||
detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
|
|
||||||
|
|
||||||
size_t n = num_group * p_fmat->Info().num_row_;
|
void InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||||
const auto &base_margin = p_fmat->Info().base_margin_.ConstHostVector();
|
float missing, PredictionCacheEntry *out_preds,
|
||||||
auto& out_preds = p_out_preds->predictions.HostVector();
|
uint32_t layer_begin, unsigned layer_end) const override {
|
||||||
out_preds.resize(n);
|
uint32_t tree_begin, tree_end;
|
||||||
if (base_margin.size() != 0) {
|
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
|
||||||
CHECK_EQ(out_preds.size(), n);
|
std::vector<Predictor const *> predictors{
|
||||||
std::copy(base_margin.begin(), base_margin.end(), out_preds.begin());
|
cpu_predictor_.get(),
|
||||||
} else {
|
#if defined(XGBOOST_USE_CUDA)
|
||||||
std::fill(out_preds.begin(), out_preds.end(),
|
gpu_predictor_.get()
|
||||||
model_.learner_model_param->base_score);
|
#endif // defined(XGBOOST_USE_CUDA)
|
||||||
|
};
|
||||||
|
|
||||||
|
MetaInfo info;
|
||||||
|
StringView msg{"Unsupported data type for inplace predict."};
|
||||||
|
// Inplace predict is not used for training, so no need to drop tree.
|
||||||
|
for (size_t i = tree_begin; i < tree_end; ++i) {
|
||||||
|
PredictionCacheEntry predts;
|
||||||
|
if (tparam_.predictor == PredictorType::kAuto) {
|
||||||
|
// Try both predictor implementations
|
||||||
|
bool success = false;
|
||||||
|
for (auto const &p : predictors) {
|
||||||
|
if (p && p->InplacePredict(x, nullptr, model_, missing, &predts, i,
|
||||||
|
i + 1)) {
|
||||||
|
success = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
CHECK(success) << msg;
|
||||||
|
} else {
|
||||||
|
// No base margin for each tree
|
||||||
|
bool success = this->GetPredictor()->InplacePredict(
|
||||||
|
x, nullptr, model_, missing, &predts, tree_begin, tree_end);
|
||||||
|
CHECK(success) << msg;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto w = this->weight_drop_.at(i);
|
||||||
|
auto &h_predts = predts.predictions.HostVector();
|
||||||
|
auto &h_out_predts = out_preds->predictions.HostVector();
|
||||||
|
if (h_out_predts.empty()) {
|
||||||
|
auto n_rows =
|
||||||
|
h_predts.size() / model_.learner_model_param->num_output_group;
|
||||||
|
if (p_m) {
|
||||||
|
p_m->Info().num_row_ = n_rows;
|
||||||
|
cpu_predictor_->InitOutPredictions(p_m->Info(),
|
||||||
|
&out_preds->predictions, model_);
|
||||||
|
} else {
|
||||||
|
info.num_row_ = n_rows;
|
||||||
|
cpu_predictor_->InitOutPredictions(info, &out_preds->predictions,
|
||||||
|
model_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiply by the tree weight
|
||||||
|
CHECK_EQ(h_predts.size(), h_out_predts.size());
|
||||||
|
for (size_t i = 0; i < h_out_predts.size(); ++i) {
|
||||||
|
// Need to remove the base margin from each individual tree.
|
||||||
|
h_out_predts[i] +=
|
||||||
|
(h_predts[i] - model_.learner_model_param->base_score) * w;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
const int nthread = omp_get_max_threads();
|
|
||||||
InitThreadTemp(nthread);
|
|
||||||
PredLoopSpecalize(p_fmat, &out_preds, num_group, tree_begin, tree_end);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void PredictInstance(const SparsePage::Inst &inst,
|
void PredictInstance(const SparsePage::Inst &inst,
|
||||||
std::vector<bst_float> *out_preds,
|
std::vector<bst_float> *out_preds,
|
||||||
unsigned layer_begin, unsigned layer_end) override {
|
unsigned layer_begin, unsigned layer_end) override {
|
||||||
DropTrees(false);
|
DropTrees(false);
|
||||||
if (thread_temp_.size() == 0) {
|
auto &predictor = this->GetPredictor();
|
||||||
thread_temp_.resize(1, RegTree::FVec());
|
uint32_t _, tree_end;
|
||||||
thread_temp_[0].Init(model_.learner_model_param->num_feature);
|
std::tie(_, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
|
||||||
}
|
predictor->PredictInstance(inst, out_preds, model_, tree_end);
|
||||||
out_preds->resize(model_.learner_model_param->num_output_group);
|
|
||||||
uint32_t tree_begin, tree_end;
|
|
||||||
std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
|
|
||||||
// loop over output groups
|
|
||||||
for (uint32_t gid = 0; gid < model_.learner_model_param->num_output_group; ++gid) {
|
|
||||||
(*out_preds)[gid] =
|
|
||||||
PredValue(inst, gid, &thread_temp_[0], 0, tree_end) +
|
|
||||||
model_.learner_model_param->base_score;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool UseGPU() const override {
|
|
||||||
return GBTree::UseGPU();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void PredictContribution(DMatrix* p_fmat,
|
void PredictContribution(DMatrix* p_fmat,
|
||||||
@ -697,60 +777,6 @@ class Dart : public GBTree {
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
inline void PredLoopSpecalize(
|
|
||||||
DMatrix* p_fmat,
|
|
||||||
std::vector<bst_float>* out_preds,
|
|
||||||
int num_group,
|
|
||||||
unsigned tree_begin,
|
|
||||||
unsigned tree_end) {
|
|
||||||
CHECK_EQ(num_group, model_.learner_model_param->num_output_group);
|
|
||||||
std::vector<bst_float>& preds = *out_preds;
|
|
||||||
CHECK_EQ(model_.param.size_leaf_vector, 0)
|
|
||||||
<< "size_leaf_vector is enforced to 0 so far";
|
|
||||||
CHECK_EQ(preds.size(), p_fmat->Info().num_row_ * num_group);
|
|
||||||
// start collecting the prediction
|
|
||||||
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
|
|
||||||
auto page = batch.GetView();
|
|
||||||
constexpr int kUnroll = 8;
|
|
||||||
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
|
|
||||||
const bst_omp_uint rest = nsize % kUnroll;
|
|
||||||
if (nsize >= kUnroll) {
|
|
||||||
#pragma omp parallel for schedule(static)
|
|
||||||
for (bst_omp_uint i = 0; i < nsize - rest; i += kUnroll) {
|
|
||||||
const int tid = omp_get_thread_num();
|
|
||||||
RegTree::FVec& feats = thread_temp_[tid];
|
|
||||||
int64_t ridx[kUnroll];
|
|
||||||
SparsePage::Inst inst[kUnroll];
|
|
||||||
for (int k = 0; k < kUnroll; ++k) {
|
|
||||||
ridx[k] = static_cast<int64_t>(batch.base_rowid + i + k);
|
|
||||||
}
|
|
||||||
for (int k = 0; k < kUnroll; ++k) {
|
|
||||||
inst[k] = page[i + k];
|
|
||||||
}
|
|
||||||
for (int k = 0; k < kUnroll; ++k) {
|
|
||||||
for (int gid = 0; gid < num_group; ++gid) {
|
|
||||||
const size_t offset = ridx[k] * num_group + gid;
|
|
||||||
preds[offset] +=
|
|
||||||
this->PredValue(inst[k], gid, &feats, tree_begin, tree_end);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (bst_omp_uint i = nsize - rest; i < nsize; ++i) {
|
|
||||||
RegTree::FVec& feats = thread_temp_[0];
|
|
||||||
const auto ridx = static_cast<int64_t>(batch.base_rowid + i);
|
|
||||||
const SparsePage::Inst inst = page[i];
|
|
||||||
for (int gid = 0; gid < num_group; ++gid) {
|
|
||||||
const size_t offset = ridx * num_group + gid;
|
|
||||||
preds[offset] +=
|
|
||||||
this->PredValue(inst, gid,
|
|
||||||
&feats, tree_begin, tree_end);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// commit new trees all at once
|
// commit new trees all at once
|
||||||
void
|
void
|
||||||
CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees,
|
CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees,
|
||||||
@ -765,32 +791,13 @@ class Dart : public GBTree {
|
|||||||
<< "weight = " << weight_drop_.back();
|
<< "weight = " << weight_drop_.back();
|
||||||
}
|
}
|
||||||
|
|
||||||
// predict the leaf scores without dropped trees
|
// Select which trees to drop.
|
||||||
bst_float PredValue(const SparsePage::Inst &inst, int bst_group,
|
|
||||||
RegTree::FVec *p_feats, unsigned tree_begin,
|
|
||||||
unsigned tree_end) const {
|
|
||||||
bst_float psum = 0.0f;
|
|
||||||
p_feats->Fill(inst);
|
|
||||||
for (size_t i = tree_begin; i < tree_end; ++i) {
|
|
||||||
if (model_.tree_info[i] == bst_group) {
|
|
||||||
bool drop = std::binary_search(idx_drop_.begin(), idx_drop_.end(), i);
|
|
||||||
if (!drop) {
|
|
||||||
int tid = model_.trees[i]->GetLeafIndex(*p_feats);
|
|
||||||
psum += weight_drop_[i] * (*model_.trees[i])[tid].LeafValue();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
p_feats->Drop(inst);
|
|
||||||
return psum;
|
|
||||||
}
|
|
||||||
|
|
||||||
// select which trees to drop
|
|
||||||
// passing clear=True will clear selection
|
|
||||||
inline void DropTrees(bool is_training) {
|
inline void DropTrees(bool is_training) {
|
||||||
idx_drop_.clear();
|
|
||||||
if (!is_training) {
|
if (!is_training) {
|
||||||
|
// This function should be thread safe when it's not training.
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
idx_drop_.clear();
|
||||||
|
|
||||||
std::uniform_real_distribution<> runif(0.0, 1.0);
|
std::uniform_real_distribution<> runif(0.0, 1.0);
|
||||||
auto& rnd = common::GlobalRandom();
|
auto& rnd = common::GlobalRandom();
|
||||||
|
|||||||
@ -201,7 +201,7 @@ class CPUPredictor : public Predictor {
|
|||||||
|
|
||||||
void InitOutPredictions(const MetaInfo& info,
|
void InitOutPredictions(const MetaInfo& info,
|
||||||
HostDeviceVector<bst_float>* out_preds,
|
HostDeviceVector<bst_float>* out_preds,
|
||||||
const gbm::GBTreeModel& model) const {
|
const gbm::GBTreeModel& model) const override {
|
||||||
CHECK_NE(model.learner_model_param->num_output_group, 0);
|
CHECK_NE(model.learner_model_param->num_output_group, 0);
|
||||||
size_t n = model.learner_model_param->num_output_group * info.num_row_;
|
size_t n = model.learner_model_param->num_output_group * info.num_row_;
|
||||||
const auto& base_margin = info.base_margin_.HostVector();
|
const auto& base_margin = info.base_margin_.HostVector();
|
||||||
@ -234,26 +234,16 @@ class CPUPredictor : public Predictor {
|
|||||||
public:
|
public:
|
||||||
explicit CPUPredictor(GenericParameter const* generic_param) :
|
explicit CPUPredictor(GenericParameter const* generic_param) :
|
||||||
Predictor::Predictor{generic_param} {}
|
Predictor::Predictor{generic_param} {}
|
||||||
|
|
||||||
void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts,
|
void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts,
|
||||||
const gbm::GBTreeModel &model, uint32_t tree_begin,
|
const gbm::GBTreeModel &model, uint32_t tree_begin,
|
||||||
uint32_t tree_end = 0) const override {
|
uint32_t tree_end = 0) const override {
|
||||||
auto* out_preds = &predts->predictions;
|
auto* out_preds = &predts->predictions;
|
||||||
if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) {
|
|
||||||
CHECK_EQ(predts->version, 0);
|
|
||||||
}
|
|
||||||
// This is actually already handled in gbm, but large amount of tests rely on the
|
// This is actually already handled in gbm, but large amount of tests rely on the
|
||||||
// behaviour.
|
// behaviour.
|
||||||
if (tree_end == 0) {
|
if (tree_end == 0) {
|
||||||
tree_end = model.trees.size();
|
tree_end = model.trees.size();
|
||||||
}
|
}
|
||||||
if (predts->version == 0) {
|
|
||||||
// out_preds->Size() can be non-zero as it's initialized here before any tree is
|
|
||||||
// built at the 0^th iterator.
|
|
||||||
this->InitOutPredictions(dmat->Info(), out_preds, model);
|
|
||||||
}
|
|
||||||
if (tree_end - tree_begin == 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
this->PredictDMatrix(dmat, &out_preds->HostVector(), model, tree_begin,
|
this->PredictDMatrix(dmat, &out_preds->HostVector(), model, tree_begin,
|
||||||
tree_end);
|
tree_end);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -599,21 +599,9 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
int device = generic_param_->gpu_id;
|
int device = generic_param_->gpu_id;
|
||||||
CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data.";
|
CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data.";
|
||||||
auto* out_preds = &predts->predictions;
|
auto* out_preds = &predts->predictions;
|
||||||
|
|
||||||
if (out_preds->Size() == 0 && dmat->Info().num_row_ != 0) {
|
|
||||||
CHECK_EQ(predts->version, 0);
|
|
||||||
}
|
|
||||||
if (tree_end == 0) {
|
if (tree_end == 0) {
|
||||||
tree_end = model.trees.size();
|
tree_end = model.trees.size();
|
||||||
}
|
}
|
||||||
if (predts->version == 0) {
|
|
||||||
// out_preds->Size() can be non-zero as it's initialized here before any tree is
|
|
||||||
// built at the 0^th iterator.
|
|
||||||
this->InitOutPredictions(dmat->Info(), out_preds, model);
|
|
||||||
}
|
|
||||||
if (tree_end - tree_begin == 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
this->DevicePredictInternal(dmat, out_preds, model, tree_begin, tree_end);
|
this->DevicePredictInternal(dmat, out_preds, model, tree_begin, tree_end);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -788,7 +776,7 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
protected:
|
protected:
|
||||||
void InitOutPredictions(const MetaInfo& info,
|
void InitOutPredictions(const MetaInfo& info,
|
||||||
HostDeviceVector<bst_float>* out_preds,
|
HostDeviceVector<bst_float>* out_preds,
|
||||||
const gbm::GBTreeModel& model) const {
|
const gbm::GBTreeModel& model) const override {
|
||||||
size_t n_classes = model.learner_model_param->num_output_group;
|
size_t n_classes = model.learner_model_param->num_output_group;
|
||||||
size_t n = n_classes * info.num_row_;
|
size_t n = n_classes * info.num_row_;
|
||||||
const HostDeviceVector<bst_float>& base_margin = info.base_margin_;
|
const HostDeviceVector<bst_float>& base_margin = info.base_margin_;
|
||||||
|
|||||||
@ -10,6 +10,7 @@
|
|||||||
#include "xgboost/learner.h"
|
#include "xgboost/learner.h"
|
||||||
#include "../helpers.h"
|
#include "../helpers.h"
|
||||||
#include "../../../src/gbm/gbtree.h"
|
#include "../../../src/gbm/gbtree.h"
|
||||||
|
#include "../../../src/data/adapter.h"
|
||||||
#include "xgboost/predictor.h"
|
#include "xgboost/predictor.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
@ -247,7 +248,9 @@ TEST(Dart, JsonIO) {
|
|||||||
TEST(Dart, Prediction) {
|
TEST(Dart, Prediction) {
|
||||||
size_t constexpr kRows = 16, kCols = 10;
|
size_t constexpr kRows = 16, kCols = 10;
|
||||||
|
|
||||||
auto p_mat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
|
HostDeviceVector<float> data;
|
||||||
|
auto array_str = RandomDataGenerator(kRows, kCols, 0).GenerateArrayInterface(&data);
|
||||||
|
auto p_mat = GetDMatrixFromData(data.HostVector(), kRows, kCols);
|
||||||
|
|
||||||
std::vector<bst_float> labels (kRows);
|
std::vector<bst_float> labels (kRows);
|
||||||
for (size_t i = 0; i < kRows; ++i) {
|
for (size_t i = 0; i < kRows; ++i) {
|
||||||
@ -265,16 +268,28 @@ TEST(Dart, Prediction) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
HostDeviceVector<float> predts_training;
|
HostDeviceVector<float> predts_training;
|
||||||
learner->Predict(p_mat, false, &predts_training, 0, true);
|
learner->Predict(p_mat, false, &predts_training, 0, 0, true);
|
||||||
HostDeviceVector<float> predts_inference;
|
|
||||||
learner->Predict(p_mat, false, &predts_inference, 0, false);
|
|
||||||
|
|
||||||
auto& h_predts_training = predts_training.ConstHostVector();
|
HostDeviceVector<float>* inplace_predts;
|
||||||
auto& h_predts_inference = predts_inference.ConstHostVector();
|
auto adapter = std::shared_ptr<data::ArrayAdapter>(new data::ArrayAdapter{StringView{array_str}});
|
||||||
|
learner->InplacePredict(adapter, nullptr, PredictionType::kValue,
|
||||||
|
std::numeric_limits<float>::quiet_NaN(),
|
||||||
|
&inplace_predts, 0, 0);
|
||||||
|
CHECK(inplace_predts);
|
||||||
|
|
||||||
|
HostDeviceVector<float> predts_inference;
|
||||||
|
learner->Predict(p_mat, false, &predts_inference, 0, 0, false);
|
||||||
|
|
||||||
|
auto const& h_predts_training = predts_training.ConstHostVector();
|
||||||
|
auto const& h_predts_inference = predts_inference.ConstHostVector();
|
||||||
|
auto const& h_inplace_predts = inplace_predts->HostVector();
|
||||||
ASSERT_EQ(h_predts_training.size(), h_predts_inference.size());
|
ASSERT_EQ(h_predts_training.size(), h_predts_inference.size());
|
||||||
|
ASSERT_EQ(h_inplace_predts.size(), h_predts_inference.size());
|
||||||
for (size_t i = 0; i < predts_inference.Size(); ++i) {
|
for (size_t i = 0; i < predts_inference.Size(); ++i) {
|
||||||
// Inference doesn't drop tree.
|
// Inference doesn't drop tree.
|
||||||
ASSERT_GT(std::abs(h_predts_training[i] - h_predts_inference[i]), kRtEps);
|
ASSERT_GT(std::abs(h_predts_training[i] - h_predts_inference[i]), kRtEps * 10);
|
||||||
|
// Inplace prediction is inference.
|
||||||
|
ASSERT_LT(h_inplace_predts[i] - h_predts_inference[i], kRtEps / 10);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -31,6 +31,7 @@ TEST(CpuPredictor, Basic) {
|
|||||||
|
|
||||||
// Test predict batch
|
// Test predict batch
|
||||||
PredictionCacheEntry out_predictions;
|
PredictionCacheEntry out_predictions;
|
||||||
|
cpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model);
|
||||||
cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
|
cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
|
||||||
|
|
||||||
std::vector<float>& out_predictions_h = out_predictions.predictions.HostVector();
|
std::vector<float>& out_predictions_h = out_predictions.predictions.HostVector();
|
||||||
@ -107,6 +108,7 @@ TEST(CpuPredictor, ExternalMemory) {
|
|||||||
|
|
||||||
// Test predict batch
|
// Test predict batch
|
||||||
PredictionCacheEntry out_predictions;
|
PredictionCacheEntry out_predictions;
|
||||||
|
cpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model);
|
||||||
cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
|
cpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
|
||||||
std::vector<float> &out_predictions_h = out_predictions.predictions.HostVector();
|
std::vector<float> &out_predictions_h = out_predictions.predictions.HostVector();
|
||||||
ASSERT_EQ(out_predictions.predictions.Size(), dmat->Info().num_row_);
|
ASSERT_EQ(out_predictions.predictions.Size(), dmat->Info().num_row_);
|
||||||
|
|||||||
@ -44,7 +44,9 @@ TEST(GPUPredictor, Basic) {
|
|||||||
PredictionCacheEntry gpu_out_predictions;
|
PredictionCacheEntry gpu_out_predictions;
|
||||||
PredictionCacheEntry cpu_out_predictions;
|
PredictionCacheEntry cpu_out_predictions;
|
||||||
|
|
||||||
|
gpu_predictor->InitOutPredictions(dmat->Info(), &gpu_out_predictions.predictions, model);
|
||||||
gpu_predictor->PredictBatch(dmat.get(), &gpu_out_predictions, model, 0);
|
gpu_predictor->PredictBatch(dmat.get(), &gpu_out_predictions, model, 0);
|
||||||
|
cpu_predictor->InitOutPredictions(dmat->Info(), &cpu_out_predictions.predictions, model);
|
||||||
cpu_predictor->PredictBatch(dmat.get(), &cpu_out_predictions, model, 0);
|
cpu_predictor->PredictBatch(dmat.get(), &cpu_out_predictions, model, 0);
|
||||||
|
|
||||||
std::vector<float>& gpu_out_predictions_h = gpu_out_predictions.predictions.HostVector();
|
std::vector<float>& gpu_out_predictions_h = gpu_out_predictions.predictions.HostVector();
|
||||||
@ -111,6 +113,7 @@ TEST(GPUPredictor, ExternalMemoryTest) {
|
|||||||
for (const auto& dmat: dmats) {
|
for (const auto& dmat: dmats) {
|
||||||
dmat->Info().base_margin_.Resize(dmat->Info().num_row_ * n_classes, 0.5);
|
dmat->Info().base_margin_.Resize(dmat->Info().num_row_ * n_classes, 0.5);
|
||||||
PredictionCacheEntry out_predictions;
|
PredictionCacheEntry out_predictions;
|
||||||
|
gpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model);
|
||||||
gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
|
gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0);
|
||||||
EXPECT_EQ(out_predictions.predictions.Size(), dmat->Info().num_row_ * n_classes);
|
EXPECT_EQ(out_predictions.predictions.Size(), dmat->Info().num_row_ * n_classes);
|
||||||
const std::vector<float> &host_vector = out_predictions.predictions.ConstHostVector();
|
const std::vector<float> &host_vector = out_predictions.predictions.ConstHostVector();
|
||||||
|
|||||||
@ -218,6 +218,7 @@ void TestCategoricalPrediction(std::string name) {
|
|||||||
row[split_ind] = split_cat;
|
row[split_ind] = split_cat;
|
||||||
auto m = GetDMatrixFromData(row, 1, kCols);
|
auto m = GetDMatrixFromData(row, 1, kCols);
|
||||||
|
|
||||||
|
predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
|
||||||
predictor->PredictBatch(m.get(), &out_predictions, model, 0);
|
predictor->PredictBatch(m.get(), &out_predictions, model, 0);
|
||||||
ASSERT_EQ(out_predictions.predictions.Size(), 1ul);
|
ASSERT_EQ(out_predictions.predictions.Size(), 1ul);
|
||||||
ASSERT_EQ(out_predictions.predictions.HostVector()[0],
|
ASSERT_EQ(out_predictions.predictions.HostVector()[0],
|
||||||
@ -226,6 +227,7 @@ void TestCategoricalPrediction(std::string name) {
|
|||||||
row[split_ind] = split_cat + 1;
|
row[split_ind] = split_cat + 1;
|
||||||
m = GetDMatrixFromData(row, 1, kCols);
|
m = GetDMatrixFromData(row, 1, kCols);
|
||||||
out_predictions.version = 0;
|
out_predictions.version = 0;
|
||||||
|
predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model);
|
||||||
predictor->PredictBatch(m.get(), &out_predictions, model, 0);
|
predictor->PredictBatch(m.get(), &out_predictions, model, 0);
|
||||||
ASSERT_EQ(out_predictions.predictions.HostVector()[0],
|
ASSERT_EQ(out_predictions.predictions.HostVector()[0],
|
||||||
left_weight + param.base_score);
|
left_weight + param.base_score);
|
||||||
|
|||||||
@ -29,9 +29,11 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, size_t cols,
|
|||||||
auto p_precise = RandomDataGenerator(rows, cols, 0).GenerateDMatrix();
|
auto p_precise = RandomDataGenerator(rows, cols, 0).GenerateDMatrix();
|
||||||
|
|
||||||
PredictionCacheEntry approx_out_predictions;
|
PredictionCacheEntry approx_out_predictions;
|
||||||
|
predictor->InitOutPredictions(p_hist->Info(), &approx_out_predictions.predictions, model);
|
||||||
predictor->PredictBatch(p_hist.get(), &approx_out_predictions, model, 0);
|
predictor->PredictBatch(p_hist.get(), &approx_out_predictions, model, 0);
|
||||||
|
|
||||||
PredictionCacheEntry precise_out_predictions;
|
PredictionCacheEntry precise_out_predictions;
|
||||||
|
predictor->InitOutPredictions(p_precise->Info(), &precise_out_predictions.predictions, model);
|
||||||
predictor->PredictBatch(p_precise.get(), &precise_out_predictions, model, 0);
|
predictor->PredictBatch(p_precise.get(), &precise_out_predictions, model, 0);
|
||||||
|
|
||||||
for (size_t i = 0; i < rows; ++i) {
|
for (size_t i = 0; i < rows; ++i) {
|
||||||
@ -46,6 +48,7 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, size_t cols,
|
|||||||
// matrix is used for training.
|
// matrix is used for training.
|
||||||
auto p_dmat = RandomDataGenerator(rows, cols, 0).GenerateDMatrix();
|
auto p_dmat = RandomDataGenerator(rows, cols, 0).GenerateDMatrix();
|
||||||
PredictionCacheEntry precise_out_predictions;
|
PredictionCacheEntry precise_out_predictions;
|
||||||
|
predictor->InitOutPredictions(p_dmat->Info(), &precise_out_predictions.predictions, model);
|
||||||
predictor->PredictBatch(p_dmat.get(), &precise_out_predictions, model, 0);
|
predictor->PredictBatch(p_dmat.get(), &precise_out_predictions, model, 0);
|
||||||
ASSERT_FALSE(p_dmat->PageExists<Page>());
|
ASSERT_FALSE(p_dmat->PageExists<Page>());
|
||||||
}
|
}
|
||||||
|
|||||||
@ -22,6 +22,16 @@ def run_threaded_predict(X, rows, predict_func):
|
|||||||
assert f.result()
|
assert f.result()
|
||||||
|
|
||||||
|
|
||||||
|
def verify_leaf_output(leaf: np.ndarray, num_parallel_tree: int):
|
||||||
|
for i in range(leaf.shape[0]): # n_samples
|
||||||
|
for j in range(leaf.shape[1]): # n_rounds
|
||||||
|
for k in range(leaf.shape[2]): # n_classes
|
||||||
|
tree_group = leaf[i, j, k, :]
|
||||||
|
assert tree_group.shape[0] == num_parallel_tree
|
||||||
|
# No sampling, all trees within forest are the same
|
||||||
|
assert np.all(tree_group == tree_group[0])
|
||||||
|
|
||||||
|
|
||||||
def run_predict_leaf(predictor):
|
def run_predict_leaf(predictor):
|
||||||
rows = 100
|
rows = 100
|
||||||
cols = 4
|
cols = 4
|
||||||
@ -53,13 +63,7 @@ def run_predict_leaf(predictor):
|
|||||||
assert leaf.shape[2] == classes
|
assert leaf.shape[2] == classes
|
||||||
assert leaf.shape[3] == num_parallel_tree
|
assert leaf.shape[3] == num_parallel_tree
|
||||||
|
|
||||||
for i in range(rows):
|
verify_leaf_output(leaf, num_parallel_tree)
|
||||||
for j in range(num_boost_round):
|
|
||||||
for k in range(classes):
|
|
||||||
tree_group = leaf[i, j, k, :]
|
|
||||||
assert tree_group.shape[0] == num_parallel_tree
|
|
||||||
# No sampling, all trees within forest are the same
|
|
||||||
assert np.all(tree_group == tree_group[0])
|
|
||||||
|
|
||||||
ntree_limit = 2
|
ntree_limit = 2
|
||||||
sliced = booster.predict(
|
sliced = booster.predict(
|
||||||
|
|||||||
@ -18,6 +18,7 @@ import hypothesis
|
|||||||
from hypothesis import given, settings, note, HealthCheck
|
from hypothesis import given, settings, note, HealthCheck
|
||||||
from test_updaters import hist_parameter_strategy, exact_parameter_strategy
|
from test_updaters import hist_parameter_strategy, exact_parameter_strategy
|
||||||
from test_with_sklearn import run_feature_weights, run_data_initialization
|
from test_with_sklearn import run_feature_weights, run_data_initialization
|
||||||
|
from test_predict import verify_leaf_output
|
||||||
|
|
||||||
if sys.platform.startswith("win"):
|
if sys.platform.startswith("win"):
|
||||||
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
pytest.skip("Skipping dask tests on Windows", allow_module_level=True)
|
||||||
@ -748,9 +749,9 @@ def test_dask_ranking(client: "Client") -> None:
|
|||||||
d = d.toarray()
|
d = d.toarray()
|
||||||
d[d == 0] = np.nan
|
d[d == 0] = np.nan
|
||||||
d[np.isinf(d)] = 0
|
d[np.isinf(d)] = 0
|
||||||
data.append(da.from_array(d))
|
data.append(dd.from_array(d, chunksize=32))
|
||||||
else:
|
else:
|
||||||
data.append(da.from_array(d))
|
data.append(dd.from_array(d, chunksize=32))
|
||||||
|
|
||||||
(
|
(
|
||||||
x_train,
|
x_train,
|
||||||
@ -782,6 +783,39 @@ def test_dask_ranking(client: "Client") -> None:
|
|||||||
assert rank.best_score > 0.98
|
assert rank.best_score > 0.98
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("booster", ["dart", "gbtree"])
|
||||||
|
def test_dask_predict_leaf(booster: str, client: "Client") -> None:
|
||||||
|
from sklearn.datasets import load_digits
|
||||||
|
|
||||||
|
X_, y_ = load_digits(return_X_y=True)
|
||||||
|
num_parallel_tree = 4
|
||||||
|
X, y = dd.from_array(X_, chunksize=32), dd.from_array(y_, chunksize=32)
|
||||||
|
rounds = 4
|
||||||
|
cls = xgb.dask.DaskXGBClassifier(
|
||||||
|
n_estimators=rounds, num_parallel_tree=num_parallel_tree, booster=booster
|
||||||
|
)
|
||||||
|
cls.client = client
|
||||||
|
cls.fit(X, y)
|
||||||
|
leaf = xgb.dask.predict(
|
||||||
|
client,
|
||||||
|
cls.get_booster(),
|
||||||
|
X.to_dask_array(), # we can't map_blocks on dataframe when output is 4-dim.
|
||||||
|
pred_leaf=True,
|
||||||
|
strict_shape=True,
|
||||||
|
validate_features=False,
|
||||||
|
).compute()
|
||||||
|
|
||||||
|
assert leaf.shape[0] == X_.shape[0]
|
||||||
|
assert leaf.shape[1] == rounds
|
||||||
|
assert leaf.shape[2] == cls.n_classes_
|
||||||
|
assert leaf.shape[3] == num_parallel_tree
|
||||||
|
|
||||||
|
leaf_from_apply = cls.apply(X).reshape(leaf.shape).compute()
|
||||||
|
np.testing.assert_allclose(leaf_from_apply, leaf)
|
||||||
|
|
||||||
|
verify_leaf_output(leaf, num_parallel_tree)
|
||||||
|
|
||||||
|
|
||||||
class TestWithDask:
|
class TestWithDask:
|
||||||
def test_global_config(self, client: "Client") -> None:
|
def test_global_config(self, client: "Client") -> None:
|
||||||
X, y, _ = generate_array()
|
X, y, _ = generate_array()
|
||||||
@ -1101,15 +1135,16 @@ class TestWithDask:
|
|||||||
assert_shape(shap.shape)
|
assert_shape(shap.shape)
|
||||||
assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)
|
assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-5, 1e-5)
|
||||||
|
|
||||||
X = dd.from_dask_array(X).repartition(npartitions=32)
|
if "num_class" not in params.keys():
|
||||||
y = dd.from_dask_array(y).repartition(npartitions=32)
|
X = dd.from_dask_array(X).repartition(npartitions=32)
|
||||||
shap_df = xgb.dask.predict(
|
y = dd.from_dask_array(y).repartition(npartitions=32)
|
||||||
client, booster, X, pred_contribs=True, validate_features=False
|
shap_df = xgb.dask.predict(
|
||||||
).compute()
|
client, booster, X, pred_contribs=True, validate_features=False
|
||||||
assert_shape(shap_df.shape)
|
).compute()
|
||||||
assert np.allclose(
|
assert_shape(shap_df.shape)
|
||||||
np.sum(shap_df, axis=len(shap_df.shape) - 1), margin, 1e-5, 1e-5
|
assert np.allclose(
|
||||||
)
|
np.sum(shap_df, axis=len(shap_df.shape) - 1), margin, 1e-5, 1e-5
|
||||||
|
)
|
||||||
|
|
||||||
def run_shap_cls_sklearn(self, X: Any, y: Any, client: "Client") -> None:
|
def run_shap_cls_sklearn(self, X: Any, y: Any, client: "Client") -> None:
|
||||||
X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
|
X, y = da.from_array(X, chunks=(32, -1)), da.from_array(y, chunks=32)
|
||||||
@ -1218,17 +1253,13 @@ class TestWithDask:
|
|||||||
np.testing.assert_allclose(predt_0.compute(), predt_3)
|
np.testing.assert_allclose(predt_0.compute(), predt_3)
|
||||||
|
|
||||||
|
|
||||||
def test_unsupported_features(client: "Client"):
|
def test_dask_unsupported_features(client: "Client") -> None:
|
||||||
X, y, _ = generate_array()
|
X, y, _ = generate_array()
|
||||||
# gblinear doesn't support distributed training.
|
# gblinear doesn't support distributed training.
|
||||||
with pytest.raises(NotImplementedError, match="gblinear"):
|
with pytest.raises(NotImplementedError, match="gblinear"):
|
||||||
xgb.dask.train(
|
xgb.dask.train(
|
||||||
client, {"booster": "gblinear"}, xgb.dask.DaskDMatrix(client, X, y)
|
client, {"booster": "gblinear"}, xgb.dask.DaskDMatrix(client, X, y)
|
||||||
)
|
)
|
||||||
# dart prediction is not thread safe, running predict with each partition will have
|
|
||||||
# race.
|
|
||||||
with pytest.raises(NotImplementedError, match="dart"):
|
|
||||||
xgb.dask.train(client, {"booster": "dart"}, xgb.dask.DaskDMatrix(client, X, y))
|
|
||||||
|
|
||||||
|
|
||||||
class TestDaskCallbacks:
|
class TestDaskCallbacks:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user