Multiclass prediction caching for CPU Hist (#6550)

Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
This commit is contained in:
ShvetsKS
2021-01-12 23:42:07 +03:00
committed by GitHub
parent 03cd087da1
commit 7f4d3a91b9
5 changed files with 94 additions and 21 deletions

View File

@@ -8,6 +8,7 @@
#include "../helpers.h"
#include "test_predictor.h"
#include "../../../src/gbm/gbtree_model.h"
#include "../../../src/gbm/gbtree.h"
#include "../../../src/data/adapter.h"
namespace xgboost {
@@ -180,6 +181,49 @@ TEST(CpuPredictor, InplacePredict) {
}
}
TEST(CpuPredictor, UpdatePredictionCache) {
size_t constexpr kRows = 64, kCols = 16, kClasses = 4;
LearnerModelParam mparam;
mparam.num_feature = kCols;
mparam.num_output_group = kClasses;
mparam.base_score = 0;
GenericParameter gparam;
gparam.Init(Args{});
std::unique_ptr<gbm::GBTree> gbm;
gbm.reset(static_cast<gbm::GBTree*>(GradientBooster::Create("gbtree", &gparam, &mparam)));
std::map<std::string, std::string> cfg;
cfg["tree_method"] = "hist";
cfg["predictor"] = "cpu_predictor";
Args args = {cfg.cbegin(), cfg.cend()};
gbm->Configure(args);
auto dmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(true, true, kClasses);
HostDeviceVector<GradientPair> gpair;
auto& h_gpair = gpair.HostVector();
h_gpair.resize(kRows*kClasses);
for (size_t i = 0; i < kRows*kClasses; ++i) {
h_gpair[i] = {static_cast<float>(i), 1};
}
PredictionCacheEntry predtion_cache;
predtion_cache.predictions.Resize(kRows*kClasses, 0);
// after one training iteration predtion_cache is filled with cached in QuantileHistMaker::Builder prediction values
gbm->DoBoost(dmat.get(), &gpair, &predtion_cache);
PredictionCacheEntry out_predictions;
// perform fair prediction on the same input data, should be equal to cached result
gbm->PredictBatch(dmat.get(), &out_predictions, false, 0);
std::vector<float> &out_predictions_h = out_predictions.predictions.HostVector();
std::vector<float> &predtion_cache_from_train = predtion_cache.predictions.HostVector();
for (size_t i = 0; i < out_predictions_h.size(); ++i) {
ASSERT_NEAR(out_predictions_h[i], predtion_cache_from_train[i], kRtEps);
}
}
TEST(CpuPredictor, LesserFeatures) {
TestPredictionWithLesserFeatures("cpu_predictor");
}