Multiclass prediction caching for CPU Hist (#6550)
Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
This commit is contained in:
@@ -110,8 +110,7 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
|
||||
}
|
||||
|
||||
bool QuantileHistMaker::UpdatePredictionCache(
|
||||
const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* out_preds) {
|
||||
const DMatrix* data, HostDeviceVector<bst_float>* out_preds) {
|
||||
if (param_.subsample < 1.0f) {
|
||||
return false;
|
||||
} else {
|
||||
@@ -125,6 +124,23 @@ bool QuantileHistMaker::UpdatePredictionCache(
|
||||
}
|
||||
}
|
||||
|
||||
bool QuantileHistMaker::UpdatePredictionCacheMulticlass(
|
||||
const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* out_preds, const int gid, const int ngroup) {
|
||||
if (param_.subsample < 1.0f) {
|
||||
return false;
|
||||
} else {
|
||||
if (hist_maker_param_.single_precision_histogram && float_builder_) {
|
||||
return float_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
|
||||
} else if (double_builder_) {
|
||||
return double_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename GradientSumT>
|
||||
void BatchHistSynchronizer<GradientSumT>::SyncHistograms(BuilderT *builder,
|
||||
int,
|
||||
@@ -620,7 +636,7 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
|
||||
template<typename GradientSumT>
|
||||
bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
|
||||
const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* p_out_preds) {
|
||||
HostDeviceVector<bst_float>* p_out_preds, const int gid, const int ngroup) {
|
||||
// p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
|
||||
// conjunction with Update().
|
||||
if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) {
|
||||
@@ -659,7 +675,7 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
|
||||
leaf_value = (*p_last_tree_)[nid].LeafValue();
|
||||
|
||||
for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
|
||||
out_preds[*it] += leaf_value;
|
||||
out_preds[*it * ngroup + gid] += leaf_value;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -121,6 +121,9 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* out_preds) override;
|
||||
bool UpdatePredictionCacheMulticlass(const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* out_preds,
|
||||
const int gid, const int ngroup) override;
|
||||
|
||||
void LoadConfig(Json const& in) override {
|
||||
auto const& config = get<Object const>(in);
|
||||
@@ -243,7 +246,9 @@ class QuantileHistMaker: public TreeUpdater {
|
||||
}
|
||||
|
||||
bool UpdatePredictionCache(const DMatrix* data,
|
||||
HostDeviceVector<bst_float>* p_out_preds);
|
||||
HostDeviceVector<bst_float>* p_out_preds,
|
||||
const int gid = 0, const int ngroup = 1);
|
||||
|
||||
void SetHistSynchronizer(HistSynchronizer<GradientSumT>* sync);
|
||||
void SetHistRowsAdder(HistRowsAdder<GradientSumT>* adder);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user