Use ellpack for prediction only when sparsepage doesn't exist. (#5504)
This commit is contained in:
@@ -31,7 +31,7 @@ TEST(GPUPredictor, Basic) {
|
||||
|
||||
for (size_t i = 1; i < 33; i *= 2) {
|
||||
int n_row = i, n_col = i;
|
||||
auto dmat = RandomDataGenerator(n_row, n_col, 0).GenerateDMatix();
|
||||
auto dmat = RandomDataGenerator(n_row, n_col, 0).GenerateDMatrix();
|
||||
|
||||
LearnerModelParam param;
|
||||
param.num_feature = n_col;
|
||||
@@ -58,16 +58,33 @@ TEST(GPUPredictor, Basic) {
|
||||
}
|
||||
|
||||
TEST(GPUPredictor, EllpackBasic) {
|
||||
size_t constexpr kCols {8};
|
||||
for (size_t bins = 2; bins < 258; bins += 16) {
|
||||
size_t rows = bins * 16;
|
||||
TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", rows, bins);
|
||||
TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", bins, bins);
|
||||
auto p_m = RandomDataGenerator{rows, kCols, 0.0}
|
||||
.Bins(bins)
|
||||
.Device(0)
|
||||
.GenerateDeviceDMatrix(true);
|
||||
TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", rows, kCols, p_m);
|
||||
TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", bins, kCols, p_m);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(GPUPredictor, EllpackTraining) {
|
||||
size_t constexpr kRows { 128 };
|
||||
TestTrainingPrediction(kRows, "gpu_hist");
|
||||
size_t constexpr kRows { 128 }, kCols { 16 }, kBins { 64 };
|
||||
auto p_ellpack = RandomDataGenerator{kRows, kCols, 0.0}
|
||||
.Bins(kBins)
|
||||
.Device(0)
|
||||
.GenerateDeviceDMatrix(true);
|
||||
std::vector<HostDeviceVector<float>> storage(kCols);
|
||||
auto columnar = RandomDataGenerator{kRows, kCols, 0.0}
|
||||
.Device(0)
|
||||
.GenerateColumnarArrayInterface(&storage);
|
||||
auto adapter = data::CudfAdapter(columnar);
|
||||
std::shared_ptr<DMatrix> p_full {
|
||||
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1)
|
||||
};
|
||||
TestTrainingPrediction(kRows, "gpu_hist", p_full, p_ellpack);
|
||||
}
|
||||
|
||||
TEST(GPUPredictor, ExternalMemoryTest) {
|
||||
|
||||
Reference in New Issue
Block a user