Use ellpack for prediction only when sparsepage doesn't exist. (#5504)

This commit is contained in:
Jiaming Yuan
2020-04-10 12:15:46 +08:00
committed by GitHub
parent ad826e913f
commit 6671b42dd4
35 changed files with 166 additions and 116 deletions

View File

@@ -68,7 +68,7 @@ TEST(Adapter, CSCAdapterColsMoreThanRows) {
}
TEST(CAPI, DMatrixSliceAdapterFromSimpleDMatrix) {
auto p_dmat = RandomDataGenerator(6, 2, 1.0).GenerateDMatix();
auto p_dmat = RandomDataGenerator(6, 2, 1.0).GenerateDMatrix();
std::vector<int> ridx_set = {1, 3, 5};
data::DMatrixSliceAdapter adapter(p_dmat.get(),

View File

@@ -23,7 +23,7 @@ TEST(DeviceDMatrix, RowMajor) {
auto adapter = common::AdapterFromData(x_device, num_rows, num_columns);
data::DeviceDMatrix dmat(&adapter,
std::numeric_limits<float>::quiet_NaN(), 1, 256);
std::numeric_limits<float>::quiet_NaN(), 1, 256);
auto &batch = *dmat.GetBatches<EllpackPage>({0, 256, 0}).begin();
auto impl = batch.Impl();
@@ -60,7 +60,7 @@ TEST(DeviceDMatrix, RowMajorMissing) {
EXPECT_EQ(iterator[1], impl->GetDeviceAccessor(0).NullValue());
EXPECT_EQ(iterator[5], impl->GetDeviceAccessor(0).NullValue());
// null values get placed after valid values in a row
EXPECT_EQ(iterator[7], impl->GetDeviceAccessor(0).NullValue());
EXPECT_EQ(iterator[7], impl->GetDeviceAccessor(0).NullValue());
EXPECT_EQ(dmat.Info().num_col_, num_columns);
EXPECT_EQ(dmat.Info().num_row_, num_rows);
EXPECT_EQ(dmat.Info().num_nonzero_, num_rows*num_columns-3);

View File

@@ -17,7 +17,7 @@ namespace xgboost {
TEST(EllpackPage, EmptyDMatrix) {
constexpr int kNRows = 0, kNCols = 0, kMaxBin = 256;
constexpr float kSparsity = 0;
auto dmat = RandomDataGenerator(kNRows, kNCols, kSparsity).GenerateDMatix();
auto dmat = RandomDataGenerator(kNRows, kNCols, kSparsity).GenerateDMatrix();
auto& page = *dmat->GetBatches<EllpackPage>({0, kMaxBin}).begin();
auto impl = page.Impl();
ASSERT_EQ(impl->row_stride, 0);

View File

@@ -220,7 +220,7 @@ TEST(SimpleDMatrix, FromFile) {
TEST(SimpleDMatrix, Slice) {
const int kRows = 6;
const int kCols = 2;
auto p_dmat = RandomDataGenerator(kRows, kCols, 1.0).GenerateDMatix();
auto p_dmat = RandomDataGenerator(kRows, kCols, 1.0).GenerateDMatrix();
auto &labels = p_dmat->Info().labels_.HostVector();
auto &weights = p_dmat->Info().weights_.HostVector();
auto &base_margin = p_dmat->Info().base_margin_.HostVector();