Predictor for vector leaf. (#8898)
This commit is contained in:
@@ -1,9 +1,16 @@
|
||||
/**
|
||||
* Copyright 2020-2023 by XGBoost Contributors
|
||||
*/
|
||||
#ifndef XGBOOST_TEST_PREDICTOR_H_
|
||||
#define XGBOOST_TEST_PREDICTOR_H_
|
||||
|
||||
#include <xgboost/context.h> // for Context
|
||||
#include <xgboost/predictor.h>
|
||||
#include <string>
|
||||
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
|
||||
#include "../../../src/gbm/gbtree_model.h" // for GBTreeModel
|
||||
#include "../helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
@@ -48,7 +55,7 @@ void TestPredictionFromGradientIndex(std::string name, size_t rows, size_t cols,
|
||||
PredictionCacheEntry precise_out_predictions;
|
||||
predictor->InitOutPredictions(p_dmat->Info(), &precise_out_predictions.predictions, model);
|
||||
predictor->PredictBatch(p_dmat.get(), &precise_out_predictions, model, 0);
|
||||
ASSERT_FALSE(p_dmat->PageExists<Page>());
|
||||
CHECK(!p_dmat->PageExists<Page>());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,6 +76,8 @@ void TestCategoricalPredictLeaf(StringView name);
|
||||
void TestIterationRange(std::string name);
|
||||
|
||||
void TestSparsePrediction(float sparsity, std::string predictor);
|
||||
|
||||
void TestVectorLeafPrediction(Context const* ctx);
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_TEST_PREDICTOR_H_
|
||||
|
||||
Reference in New Issue
Block a user