Normal prediction with DMatrix is now thread-safe (guarded by locks). The newly added in-place prediction is lock-free and thread-safe. When the data is on a device (CuPy, cuDF), the returned data is also on the device. * Implementation for NumPy, CSR, cuDF and CuPy. * Implementation for Dask. * Remove sync in SimpleDMatrix.
67 lines
2.2 KiB
C++
67 lines
2.2 KiB
C++
#ifndef XGBOOST_TEST_PREDICTOR_H_
|
|
#define XGBOOST_TEST_PREDICTOR_H_
|
|
|
|
#include <xgboost/predictor.h>

#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>

#include "../helpers.h"
|
|
|
|
namespace xgboost {
|
|
template <typename Page>
|
|
void TestPredictionFromGradientIndex(std::string name, size_t rows, int32_t bins) {
|
|
constexpr size_t kCols { 8 }, kClasses { 3 };
|
|
|
|
LearnerModelParam param;
|
|
param.num_feature = kCols;
|
|
param.num_output_group = kClasses;
|
|
param.base_score = 0.5;
|
|
|
|
auto lparam = CreateEmptyGenericParam(0);
|
|
|
|
std::unique_ptr<Predictor> predictor =
|
|
std::unique_ptr<Predictor>(Predictor::Create(name, &lparam));
|
|
predictor->Configure({});
|
|
|
|
gbm::GBTreeModel model = CreateTestModel(¶m, kClasses);
|
|
|
|
{
|
|
auto p_ellpack = RandomDataGenerator(rows, kCols, 0).GenerateDMatix();
|
|
// Use same number of bins as rows.
|
|
for (auto const &page DMLC_ATTRIBUTE_UNUSED :
|
|
p_ellpack->GetBatches<Page>({0, static_cast<int32_t>(bins), 0})) {
|
|
}
|
|
|
|
auto p_precise = RandomDataGenerator(rows, kCols, 0).GenerateDMatix();
|
|
|
|
PredictionCacheEntry approx_out_predictions;
|
|
predictor->PredictBatch(p_ellpack.get(), &approx_out_predictions, model, 0);
|
|
|
|
PredictionCacheEntry precise_out_predictions;
|
|
predictor->PredictBatch(p_precise.get(), &precise_out_predictions, model, 0);
|
|
|
|
for (size_t i = 0; i < rows; ++i) {
|
|
CHECK_EQ(approx_out_predictions.predictions.HostVector()[i],
|
|
precise_out_predictions.predictions.HostVector()[i]);
|
|
}
|
|
}
|
|
|
|
{
|
|
// Predictor should never try to create the histogram index by itself. As only
|
|
// histogram index from training data is valid and predictor doesn't known which
|
|
// matrix is used for training.
|
|
auto p_dmat = RandomDataGenerator(rows, kCols, 0).GenerateDMatix();
|
|
PredictionCacheEntry precise_out_predictions;
|
|
predictor->PredictBatch(p_dmat.get(), &precise_out_predictions, model, 0);
|
|
ASSERT_FALSE(p_dmat->PageExists<Page>());
|
|
}
|
|
}
|
|
|
|
void TestTrainingPrediction(size_t rows, std::string tree_method);
|
|
|
|
void TestInplacePrediction(dmlc::any x, std::string predictor,
|
|
bst_row_t rows, bst_feature_t cols,
|
|
int32_t device = -1);
|
|
} // namespace xgboost
|
|
|
|
#endif // XGBOOST_TEST_PREDICTOR_H_
|