merge latest changes
This commit is contained in:
@@ -148,7 +148,7 @@ TEST(CPUPredictor, GHistIndexTraining) {
|
||||
auto adapter = data::ArrayAdapter(columnar.c_str());
|
||||
std::shared_ptr<DMatrix> p_full{
|
||||
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1)};
|
||||
TestTrainingPrediction(&ctx, kRows, kBins, p_full, p_hist);
|
||||
TestTrainingPrediction(&ctx, kRows, kBins, p_full, p_hist, true);
|
||||
}
|
||||
|
||||
TEST(CPUPredictor, CategoricalPrediction) {
|
||||
|
||||
@@ -118,7 +118,8 @@ TEST(Predictor, PredictionCache) {
|
||||
}
|
||||
|
||||
void TestTrainingPrediction(Context const *ctx, size_t rows, size_t bins,
|
||||
std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist) {
|
||||
std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist,
|
||||
bool check_contribs) {
|
||||
size_t constexpr kCols = 16;
|
||||
size_t constexpr kClasses = 3;
|
||||
size_t constexpr kIters = 3;
|
||||
@@ -161,6 +162,28 @@ void TestTrainingPrediction(Context const *ctx, size_t rows, size_t bins,
|
||||
for (size_t i = 0; i < rows; ++i) {
|
||||
EXPECT_NEAR(from_hist.ConstHostVector()[i], from_full.ConstHostVector()[i], kRtEps);
|
||||
}
|
||||
|
||||
if (check_contribs) {
|
||||
// Contributions
|
||||
HostDeviceVector<float> from_full_contribs;
|
||||
learner->Predict(p_full, false, &from_full_contribs, 0, 0, false, false, true);
|
||||
HostDeviceVector<float> from_hist_contribs;
|
||||
learner->Predict(p_hist, false, &from_hist_contribs, 0, 0, false, false, true);
|
||||
for (size_t i = 0; i < from_full_contribs.ConstHostVector().size(); ++i) {
|
||||
EXPECT_NEAR(from_hist_contribs.ConstHostVector()[i],
|
||||
from_full_contribs.ConstHostVector()[i], kRtEps);
|
||||
}
|
||||
|
||||
// Contributions (approximate method)
|
||||
HostDeviceVector<float> from_full_approx_contribs;
|
||||
learner->Predict(p_full, false, &from_full_approx_contribs, 0, 0, false, false, false, true);
|
||||
HostDeviceVector<float> from_hist_approx_contribs;
|
||||
learner->Predict(p_hist, false, &from_hist_approx_contribs, 0, 0, false, false, false, true);
|
||||
for (size_t i = 0; i < from_full_approx_contribs.ConstHostVector().size(); ++i) {
|
||||
EXPECT_NEAR(from_hist_approx_contribs.ConstHostVector()[i],
|
||||
from_full_approx_contribs.ConstHostVector()[i], kRtEps);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TestInplacePrediction(Context const *ctx, std::shared_ptr<DMatrix> x, bst_row_t rows,
|
||||
|
||||
@@ -89,7 +89,8 @@ void TestBasic(DMatrix* dmat, Context const * ctx);
|
||||
|
||||
// p_full and p_hist should come from the same data set.
|
||||
void TestTrainingPrediction(Context const* ctx, size_t rows, size_t bins,
|
||||
std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist);
|
||||
std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist,
|
||||
bool check_contribs = false);
|
||||
|
||||
void TestInplacePrediction(Context const* ctx, std::shared_ptr<DMatrix> x, bst_row_t rows,
|
||||
bst_feature_t cols);
|
||||
|
||||
Reference in New Issue
Block a user