Multiclass prediction caching for CPU Hist (#6550)
Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
parent 03cd087da1
commit 7f4d3a91b9
@@ -70,11 +70,14 @@ class TreeUpdater : public Configurable {
   * the prediction cache. If true, the prediction cache will have been
   * updated by the time this function returns.
   */
-  virtual bool UpdatePredictionCache(const DMatrix* data,
-                                     HostDeviceVector<bst_float>* out_preds) {
-    // Remove unused parameter compiler warning.
-    (void) data;
-    (void) out_preds;
+  virtual bool UpdatePredictionCache(const DMatrix* /*data*/,
+                                     HostDeviceVector<bst_float>* /*out_preds*/) {
     return false;
   }
 
+  virtual bool UpdatePredictionCacheMulticlass(const DMatrix* /*data*/,
+                                               HostDeviceVector<bst_float>* /*out_preds*/,
+                                               const int /*gid*/, const int /*ngroup*/) {
+    return false;
+  }

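For orientation, here is a minimal standalone sketch of the contract these two hooks define. It uses plain std::vector stand-ins instead of the real DMatrix/HostDeviceVector types, and the class name is illustrative only. The base-class defaults return false, which tells the caller the cache was not refreshed and predictions must be recomputed; an updater that can serve the cache writes into the group-interleaved prediction buffer and returns true.

// Standalone illustration of the hook contract above; not the XGBoost classes.
#include <vector>

struct PredictionCacheHooksSketch {
  virtual ~PredictionCacheHooksSketch() = default;

  // Single-output case: refresh the whole cached prediction vector, or decline.
  virtual bool UpdatePredictionCache(std::vector<float>* /*out_preds*/) {
    return false;  // default: cache untouched, caller must recompute predictions
  }

  // Multiclass case: out_preds holds nrow * ngroup values, row-major and
  // interleaved by class, so group `gid` of row `r` lives at r * ngroup + gid.
  virtual bool UpdatePredictionCacheMulticlass(std::vector<float>* /*out_preds*/,
                                               int /*gid*/, int /*ngroup*/) {
    return false;
  }
};
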
@@ -196,11 +196,17 @@ void GBTree::DoBoost(DMatrix* p_fmat,
   const int ngroup = model_.learner_model_param->num_output_group;
   ConfigureWithKnownData(this->cfg_, p_fmat);
   monitor_.Start("BoostNewTrees");
+  auto* out = &predt->predictions;
   CHECK_NE(ngroup, 0);
   if (ngroup == 1) {
     std::vector<std::unique_ptr<RegTree> > ret;
     BoostNewTrees(in_gpair, p_fmat, 0, &ret);
+    const size_t num_new_trees = ret.size();
     new_trees.push_back(std::move(ret));
+    if (updaters_.size() > 0 && num_new_trees == 1 && out->Size() > 0 &&
+        updaters_.back()->UpdatePredictionCache(p_fmat, out)) {
+      predt->Update(1);
+    }
   } else {
     CHECK_EQ(in_gpair->Size() % ngroup, 0U)
         << "must have exactly ngroup * nrow gpairs";

@@ -210,6 +216,7 @@ void GBTree::DoBoost(DMatrix* p_fmat,
                                          in_gpair->DeviceIdx());
     const auto& gpair_h = in_gpair->ConstHostVector();
     auto nsize = static_cast<bst_omp_uint>(tmp.Size());
+    bool update_predict = true;
     for (int gid = 0; gid < ngroup; ++gid) {
       std::vector<GradientPair>& tmp_h = tmp.HostVector();
 #pragma omp parallel for schedule(static)

@@ -218,7 +225,16 @@ void GBTree::DoBoost(DMatrix* p_fmat,
       }
       std::vector<std::unique_ptr<RegTree> > ret;
       BoostNewTrees(&tmp, p_fmat, gid, &ret);
+      const size_t num_new_trees = ret.size();
       new_trees.push_back(std::move(ret));
+      auto* out = &predt->predictions;
+      if (!(updaters_.size() > 0 && out->Size() > 0 && num_new_trees == 1 &&
+            updaters_.back()->UpdatePredictionCacheMulticlass(p_fmat, out, gid, ngroup))) {
+        update_predict = false;
+      }
     }
+    if (update_predict) {
+      predt->Update(1);
+    }
   }
   monitor_.Stop("BoostNewTrees");

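A condensed, hypothetical sketch of the control flow the two hunks above add for the multiclass branch: one tree is boosted per class group, each group then asks the last updater to refresh its slice of the cached predictions, and the cache version is bumped only if every group succeeded. The function and parameter names below are stand-ins, not the real GBTree members.

// Hypothetical condensed version of the multiclass cache-update flow above.
#include <functional>
#include <vector>

bool BoostAllGroupsSketch(int ngroup, std::vector<float>* cached_preds,
                          const std::function<void(int)>& boost_one_tree,
                          const std::function<bool(std::vector<float>*, int, int)>&
                              try_update_group_cache) {
  bool update_predict = true;  // stays true only if every group refreshed the cache
  for (int gid = 0; gid < ngroup; ++gid) {
    boost_one_tree(gid);  // build this group's new tree from its gradient slice
    if (!try_update_group_cache(cached_preds, gid, ngroup)) {
      update_predict = false;  // a single miss invalidates the whole cached entry
    }
  }
  return update_predict;  // caller bumps the cache version only when this is true
}
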
@@ -314,20 +330,9 @@ void GBTree::CommitModel(std::vector<std::vector<std::unique_ptr<RegTree>>>&& new_trees,
                          DMatrix* m,
                          PredictionCacheEntry* predts) {
   monitor_.Start("CommitModel");
-  int num_new_trees = 0;
   for (uint32_t gid = 0; gid < model_.learner_model_param->num_output_group; ++gid) {
-    num_new_trees += new_trees[gid].size();
     model_.CommitModel(std::move(new_trees[gid]), gid);
   }
-  auto* out = &predts->predictions;
-  if (model_.learner_model_param->num_output_group == 1 &&
-      updaters_.size() > 0 &&
-      num_new_trees == 1 &&
-      out->Size() > 0 &&
-      updaters_.back()->UpdatePredictionCache(m, out)) {
-    auto delta = num_new_trees / model_.learner_model_param->num_output_group;
-    predts->Update(delta);
-  }
   monitor_.Stop("CommitModel");
 }

@@ -110,8 +110,7 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
 }
 
 bool QuantileHistMaker::UpdatePredictionCache(
-    const DMatrix* data,
-    HostDeviceVector<bst_float>* out_preds) {
+    const DMatrix* data, HostDeviceVector<bst_float>* out_preds) {
   if (param_.subsample < 1.0f) {
     return false;
   } else {

@@ -125,6 +124,23 @@ bool QuantileHistMaker::UpdatePredictionCache(
   }
 }
 
+bool QuantileHistMaker::UpdatePredictionCacheMulticlass(
+    const DMatrix* data,
+    HostDeviceVector<bst_float>* out_preds, const int gid, const int ngroup) {
+  if (param_.subsample < 1.0f) {
+    return false;
+  } else {
+    if (hist_maker_param_.single_precision_histogram && float_builder_) {
+      return float_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
+    } else if (double_builder_) {
+      return double_builder_->UpdatePredictionCache(data, out_preds, gid, ngroup);
+    } else {
+      return false;
+    }
+  }
+}
+
 template <typename GradientSumT>
 void BatchHistSynchronizer<GradientSumT>::SyncHistograms(BuilderT *builder,
                                                          int,

@@ -620,7 +636,7 @@ void QuantileHistMaker::Builder<GradientSumT>::Update(
 template<typename GradientSumT>
 bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
     const DMatrix* data,
-    HostDeviceVector<bst_float>* p_out_preds) {
+    HostDeviceVector<bst_float>* p_out_preds, const int gid, const int ngroup) {
   // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in
   // conjunction with Update().
   if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) {

@@ -659,7 +675,7 @@ bool QuantileHistMaker::Builder<GradientSumT>::UpdatePredictionCache(
       leaf_value = (*p_last_tree_)[nid].LeafValue();
 
       for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) {
-        out_preds[*it] += leaf_value;
+        out_preds[*it * ngroup + gid] += leaf_value;
       }
     }
   });

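The one-line change above carries the core of the feature: with ngroup output classes, cached predictions are stored row-major and interleaved by class, so the leaf value a tree of group gid contributes to row r is accumulated at index r * ngroup + gid (the old code assumed a single group and indexed by row alone). A tiny self-contained example of that layout, with illustrative values only:

// Self-contained illustration of the group-interleaved prediction layout.
#include <cstdio>
#include <vector>

int main() {
  const int nrow = 3, ngroup = 4;                     // e.g. 3 rows, 4 classes
  std::vector<float> out_preds(nrow * ngroup, 0.5f);  // every slot starts at the base margin

  // Accumulate a leaf value for class 2 of row 1, mirroring
  //   out_preds[*it * ngroup + gid] += leaf_value;   from the hunk above.
  const int row = 1, gid = 2;
  const float leaf_value = 0.25f;
  out_preds[row * ngroup + gid] += leaf_value;

  std::printf("row %d, class %d -> %f\n", row, gid, out_preds[row * ngroup + gid]);
  return 0;
}
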
@@ -121,6 +121,9 @@ class QuantileHistMaker: public TreeUpdater {
 
   bool UpdatePredictionCache(const DMatrix* data,
                              HostDeviceVector<bst_float>* out_preds) override;
+  bool UpdatePredictionCacheMulticlass(const DMatrix* data,
+                                       HostDeviceVector<bst_float>* out_preds,
+                                       const int gid, const int ngroup) override;
 
   void LoadConfig(Json const& in) override {
     auto const& config = get<Object const>(in);

@@ -243,7 +246,9 @@ class QuantileHistMaker: public TreeUpdater {
     }
 
     bool UpdatePredictionCache(const DMatrix* data,
-                               HostDeviceVector<bst_float>* p_out_preds);
+                               HostDeviceVector<bst_float>* p_out_preds,
+                               const int gid = 0, const int ngroup = 1);
 
     void SetHistSynchronizer(HistSynchronizer<GradientSumT>* sync);
     void SetHistRowsAdder(HistRowsAdder<GradientSumT>* adder);

@@ -8,6 +8,7 @@
 #include "../helpers.h"
 #include "test_predictor.h"
 #include "../../../src/gbm/gbtree_model.h"
+#include "../../../src/gbm/gbtree.h"
 #include "../../../src/data/adapter.h"
 
 namespace xgboost {

@@ -180,6 +181,49 @@ TEST(CpuPredictor, InplacePredict) {
   }
 }
 
+TEST(CpuPredictor, UpdatePredictionCache) {
+  size_t constexpr kRows = 64, kCols = 16, kClasses = 4;
+  LearnerModelParam mparam;
+  mparam.num_feature = kCols;
+  mparam.num_output_group = kClasses;
+  mparam.base_score = 0;
+
+  GenericParameter gparam;
+  gparam.Init(Args{});
+
+  std::unique_ptr<gbm::GBTree> gbm;
+  gbm.reset(static_cast<gbm::GBTree*>(GradientBooster::Create("gbtree", &gparam, &mparam)));
+  std::map<std::string, std::string> cfg;
+  cfg["tree_method"] = "hist";
+  cfg["predictor"] = "cpu_predictor";
+  Args args = {cfg.cbegin(), cfg.cend()};
+  gbm->Configure(args);
+
+  auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
+
+  HostDeviceVector<GradientPair> gpair;
+  auto& h_gpair = gpair.HostVector();
+  h_gpair.resize(kRows*kClasses);
+  for (size_t i = 0; i < kRows*kClasses; ++i) {
+    h_gpair[i] = {static_cast<float>(i), 1};
+  }
+
+  PredictionCacheEntry predtion_cache;
+  predtion_cache.predictions.Resize(kRows*kClasses, 0);
+  // After one training iteration, predtion_cache is filled with the values cached in QuantileHistMaker::Builder.
+  gbm->DoBoost(dmat.get(), &gpair, &predtion_cache);
+
+  PredictionCacheEntry out_predictions;
+  // Run a full prediction on the same input data; it should match the cached result.
+  gbm->PredictBatch(dmat.get(), &out_predictions, false, 0);
+
+  std::vector<float> &out_predictions_h = out_predictions.predictions.HostVector();
+  std::vector<float> &predtion_cache_from_train = predtion_cache.predictions.HostVector();
+  for (size_t i = 0; i < out_predictions_h.size(); ++i) {
+    ASSERT_NEAR(out_predictions_h[i], predtion_cache_from_train[i], kRtEps);
+  }
+}
+
 TEST(CpuPredictor, LesserFeatures) {
   TestPredictionWithLesserFeatures("cpu_predictor");
 }