Refactor linear modelling and add new coordinate descent updater (#3103)
* Refactor linear modelling and add new coordinate descent updater * Allow unsorted column iterator * Add prediction cacheing to gblinear
This commit is contained in:
@@ -42,11 +42,18 @@ TEST(SimpleDMatrix, ColAccessWithoutBatches) {
|
||||
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, false);
|
||||
std::remove(tmp_file.c_str());
|
||||
|
||||
EXPECT_EQ(dmat->HaveColAccess(), false);
|
||||
// Unsorted column access
|
||||
const std::vector<bool> enable(dmat->info().num_col, true);
|
||||
dmat->InitColAccess(enable, 1, dmat->info().num_row);
|
||||
dmat->InitColAccess(enable, 0, 0); // Calling it again should not change it
|
||||
ASSERT_EQ(dmat->HaveColAccess(), true);
|
||||
EXPECT_EQ(dmat->HaveColAccess(false), false);
|
||||
dmat->InitColAccess(enable, 1, dmat->info().num_row, false);
|
||||
dmat->InitColAccess(enable, 0, 0, false); // Calling it again should not change it
|
||||
ASSERT_EQ(dmat->HaveColAccess(false), true);
|
||||
|
||||
// Sorted column access
|
||||
EXPECT_EQ(dmat->HaveColAccess(true), false);
|
||||
dmat->InitColAccess(enable, 1, dmat->info().num_row, true);
|
||||
dmat->InitColAccess(enable, 0, 0, true); // Calling it again should not change it
|
||||
ASSERT_EQ(dmat->HaveColAccess(true), true);
|
||||
|
||||
EXPECT_EQ(dmat->GetColSize(0), 2);
|
||||
EXPECT_EQ(dmat->GetColSize(1), 1);
|
||||
@@ -86,11 +93,18 @@ TEST(SimpleDMatrix, ColAccessWithBatches) {
|
||||
xgboost::DMatrix * dmat = xgboost::DMatrix::Load(tmp_file, true, false);
|
||||
std::remove(tmp_file.c_str());
|
||||
|
||||
EXPECT_EQ(dmat->HaveColAccess(), false);
|
||||
// Unsorted column access
|
||||
const std::vector<bool> enable(dmat->info().num_col, true);
|
||||
dmat->InitColAccess(enable, 1, 1); // Max 1 row per patch
|
||||
dmat->InitColAccess(enable, 0, 0); // Calling it again should not change it
|
||||
ASSERT_EQ(dmat->HaveColAccess(), true);
|
||||
EXPECT_EQ(dmat->HaveColAccess(false), false);
|
||||
dmat->InitColAccess(enable, 1, 1, false);
|
||||
dmat->InitColAccess(enable, 0, 0, false); // Calling it again should not change it
|
||||
ASSERT_EQ(dmat->HaveColAccess(false), true);
|
||||
|
||||
// Sorted column access
|
||||
EXPECT_EQ(dmat->HaveColAccess(true), false);
|
||||
dmat->InitColAccess(enable, 1, 1, true); // Max 1 row per patch
|
||||
dmat->InitColAccess(enable, 0, 0, true); // Calling it again should not change it
|
||||
ASSERT_EQ(dmat->HaveColAccess(true), true);
|
||||
|
||||
EXPECT_EQ(dmat->GetColSize(0), 2);
|
||||
EXPECT_EQ(dmat->GetColSize(1), 1);
|
||||
|
||||
@@ -56,10 +56,10 @@ TEST(SparsePageDMatrix, ColAcess) {
|
||||
std::remove(tmp_file.c_str());
|
||||
EXPECT_FALSE(FileExists(tmp_file + ".cache.col.page"));
|
||||
|
||||
EXPECT_EQ(dmat->HaveColAccess(), false);
|
||||
EXPECT_EQ(dmat->HaveColAccess(true), false);
|
||||
const std::vector<bool> enable(dmat->info().num_col, true);
|
||||
dmat->InitColAccess(enable, 1, 1); // Max 1 row per patch
|
||||
ASSERT_EQ(dmat->HaveColAccess(), true);
|
||||
dmat->InitColAccess(enable, 1, 1, true); // Max 1 row per patch
|
||||
ASSERT_EQ(dmat->HaveColAccess(true), true);
|
||||
EXPECT_TRUE(FileExists(tmp_file + ".cache.col.page"));
|
||||
|
||||
EXPECT_EQ(dmat->GetColSize(0), 2);
|
||||
|
||||
44
tests/cpp/linear/test_linear.cc
Normal file
44
tests/cpp/linear/test_linear.cc
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright by Contributors
|
||||
#include <xgboost/linear_updater.h>
|
||||
#include "../helpers.h"
|
||||
#include "xgboost/gbm.h"
|
||||
|
||||
typedef std::pair<std::string, std::string> arg;
|
||||
|
||||
TEST(Linear, shotgun) {
|
||||
typedef std::pair<std::string, std::string> arg;
|
||||
auto mat = CreateDMatrix(10, 10, 0);
|
||||
std::vector<bool> enabled(mat->info().num_col, true);
|
||||
mat->InitColAccess(enabled, 1.0f, 1 << 16, false);
|
||||
auto updater = std::unique_ptr<xgboost::LinearUpdater>(
|
||||
xgboost::LinearUpdater::Create("shotgun"));
|
||||
updater->Init({});
|
||||
std::vector<xgboost::bst_gpair> gpair(mat->info().num_row,
|
||||
xgboost::bst_gpair(-5, 1.0));
|
||||
xgboost::gbm::GBLinearModel model;
|
||||
model.param.num_feature = mat->info().num_col;
|
||||
model.param.num_output_group = 1;
|
||||
model.LazyInitModel();
|
||||
updater->Update(&gpair, mat.get(), &model, gpair.size());
|
||||
|
||||
ASSERT_EQ(model.bias()[0], 5.0f);
|
||||
}
|
||||
|
||||
TEST(Linear, coordinate) {
|
||||
typedef std::pair<std::string, std::string> arg;
|
||||
auto mat = CreateDMatrix(10, 10, 0);
|
||||
std::vector<bool> enabled(mat->info().num_col, true);
|
||||
mat->InitColAccess(enabled, 1.0f, 1 << 16, false);
|
||||
auto updater = std::unique_ptr<xgboost::LinearUpdater>(
|
||||
xgboost::LinearUpdater::Create("coord_descent"));
|
||||
updater->Init({});
|
||||
std::vector<xgboost::bst_gpair> gpair(mat->info().num_row,
|
||||
xgboost::bst_gpair(-5, 1.0));
|
||||
xgboost::gbm::GBLinearModel model;
|
||||
model.param.num_feature = mat->info().num_col;
|
||||
model.param.num_output_group = 1;
|
||||
model.LazyInitModel();
|
||||
updater->Update(&gpair, mat.get(), &model, gpair.size());
|
||||
|
||||
ASSERT_EQ(model.bias()[0], 5.0f);
|
||||
}
|
||||
Reference in New Issue
Block a user