Refactor linear modelling and add new coordinate descent updater (#3103)
* Refactor linear modelling and add new coordinate descent updater * Allow unsorted column iterator * Add prediction cacheing to gblinear
This commit is contained in:
@@ -42,11 +42,18 @@ TEST(SimpleDMatrix, ColAccessWithoutBatches) {
|
||||
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, false);
|
||||
std::remove(tmp_file.c_str());
|
||||
|
||||
EXPECT_EQ(dmat->HaveColAccess(), false);
|
||||
// Unsorted column access
|
||||
const std::vector<bool> enable(dmat->info().num_col, true);
|
||||
dmat->InitColAccess(enable, 1, dmat->info().num_row);
|
||||
dmat->InitColAccess(enable, 0, 0); // Calling it again should not change it
|
||||
ASSERT_EQ(dmat->HaveColAccess(), true);
|
||||
EXPECT_EQ(dmat->HaveColAccess(false), false);
|
||||
dmat->InitColAccess(enable, 1, dmat->info().num_row, false);
|
||||
dmat->InitColAccess(enable, 0, 0, false); // Calling it again should not change it
|
||||
ASSERT_EQ(dmat->HaveColAccess(false), true);
|
||||
|
||||
// Sorted column access
|
||||
EXPECT_EQ(dmat->HaveColAccess(true), false);
|
||||
dmat->InitColAccess(enable, 1, dmat->info().num_row, true);
|
||||
dmat->InitColAccess(enable, 0, 0, true); // Calling it again should not change it
|
||||
ASSERT_EQ(dmat->HaveColAccess(true), true);
|
||||
|
||||
EXPECT_EQ(dmat->GetColSize(0), 2);
|
||||
EXPECT_EQ(dmat->GetColSize(1), 1);
|
||||
@@ -86,11 +93,18 @@ TEST(SimpleDMatrix, ColAccessWithBatches) {
|
||||
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, false);
|
||||
std::remove(tmp_file.c_str());
|
||||
|
||||
EXPECT_EQ(dmat->HaveColAccess(), false);
|
||||
// Unsorted column access
|
||||
const std::vector<bool> enable(dmat->info().num_col, true);
|
||||
dmat->InitColAccess(enable, 1, 1); // Max 1 row per patch
|
||||
dmat->InitColAccess(enable, 0, 0); // Calling it again should not change it
|
||||
ASSERT_EQ(dmat->HaveColAccess(), true);
|
||||
EXPECT_EQ(dmat->HaveColAccess(false), false);
|
||||
dmat->InitColAccess(enable, 1, 1, false);
|
||||
dmat->InitColAccess(enable, 0, 0, false); // Calling it again should not change it
|
||||
ASSERT_EQ(dmat->HaveColAccess(false), true);
|
||||
|
||||
// Sorted column access
|
||||
EXPECT_EQ(dmat->HaveColAccess(true), false);
|
||||
dmat->InitColAccess(enable, 1, 1, true); // Max 1 row per patch
|
||||
dmat->InitColAccess(enable, 0, 0, true); // Calling it again should not change it
|
||||
ASSERT_EQ(dmat->HaveColAccess(true), true);
|
||||
|
||||
EXPECT_EQ(dmat->GetColSize(0), 2);
|
||||
EXPECT_EQ(dmat->GetColSize(1), 1);
|
||||
|
||||
@@ -56,10 +56,10 @@ TEST(SparsePageDMatrix, ColAcess) {
|
||||
std::remove(tmp_file.c_str());
|
||||
EXPECT_FALSE(FileExists(tmp_file + ".cache.col.page"));
|
||||
|
||||
EXPECT_EQ(dmat->HaveColAccess(), false);
|
||||
EXPECT_EQ(dmat->HaveColAccess(true), false);
|
||||
const std::vector<bool> enable(dmat->info().num_col, true);
|
||||
dmat->InitColAccess(enable, 1, 1); // Max 1 row per patch
|
||||
ASSERT_EQ(dmat->HaveColAccess(), true);
|
||||
dmat->InitColAccess(enable, 1, 1, true); // Max 1 row per patch
|
||||
ASSERT_EQ(dmat->HaveColAccess(true), true);
|
||||
EXPECT_TRUE(FileExists(tmp_file + ".cache.col.page"));
|
||||
|
||||
EXPECT_EQ(dmat->GetColSize(0), 2);
|
||||
|
||||
Reference in New Issue
Block a user