sync upstream code
This commit is contained in:
@@ -65,7 +65,7 @@ TEST(CpuPredictor, ExternalMemory) {
|
||||
}
|
||||
|
||||
TEST(CpuPredictor, InplacePredict) {
|
||||
bst_row_t constexpr kRows{128};
|
||||
bst_idx_t constexpr kRows{128};
|
||||
bst_feature_t constexpr kCols{64};
|
||||
Context ctx;
|
||||
auto gen = RandomDataGenerator{kRows, kCols, 0.5}.Device(ctx.Device());
|
||||
@@ -83,7 +83,7 @@ TEST(CpuPredictor, InplacePredict) {
|
||||
|
||||
{
|
||||
HostDeviceVector<float> data;
|
||||
HostDeviceVector<bst_row_t> rptrs;
|
||||
HostDeviceVector<std::size_t> rptrs;
|
||||
HostDeviceVector<bst_feature_t> columns;
|
||||
gen.GenerateCSR(&data, &rptrs, &columns);
|
||||
auto data_interface = GetArrayInterface(&data, kRows * kCols, 1);
|
||||
|
||||
@@ -186,7 +186,7 @@ void TestTrainingPrediction(Context const *ctx, size_t rows, size_t bins,
|
||||
}
|
||||
}
|
||||
|
||||
void TestInplacePrediction(Context const *ctx, std::shared_ptr<DMatrix> x, bst_row_t rows,
|
||||
void TestInplacePrediction(Context const *ctx, std::shared_ptr<DMatrix> x, bst_idx_t rows,
|
||||
bst_feature_t cols) {
|
||||
std::size_t constexpr kClasses { 4 };
|
||||
auto gen = RandomDataGenerator{rows, cols, 0.5}.Device(ctx->Device());
|
||||
@@ -255,7 +255,7 @@ std::unique_ptr<Learner> LearnerForTest(Context const *ctx, std::shared_ptr<DMat
|
||||
return learner;
|
||||
}
|
||||
|
||||
void VerifyPredictionWithLesserFeatures(Learner *learner, bst_row_t kRows,
|
||||
void VerifyPredictionWithLesserFeatures(Learner *learner, bst_idx_t kRows,
|
||||
std::shared_ptr<DMatrix> m_test,
|
||||
std::shared_ptr<DMatrix> m_invalid) {
|
||||
HostDeviceVector<float> prediction;
|
||||
|
||||
@@ -92,7 +92,7 @@ void TestTrainingPrediction(Context const* ctx, size_t rows, size_t bins,
|
||||
std::shared_ptr<DMatrix> p_full, std::shared_ptr<DMatrix> p_hist,
|
||||
bool check_contribs = false);
|
||||
|
||||
void TestInplacePrediction(Context const* ctx, std::shared_ptr<DMatrix> x, bst_row_t rows,
|
||||
void TestInplacePrediction(Context const* ctx, std::shared_ptr<DMatrix> x, bst_idx_t rows,
|
||||
bst_feature_t cols);
|
||||
|
||||
void TestPredictionWithLesserFeatures(Context const* ctx);
|
||||
|
||||
Reference in New Issue
Block a user