Refactor linear modelling and add new coordinate descent updater (#3103)
* Refactor linear modelling and add new coordinate descent updater * Allow unsorted column iterator * Add prediction cacheing to gblinear
This commit is contained in:
44
tests/cpp/linear/test_linear.cc
Normal file
44
tests/cpp/linear/test_linear.cc
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright by Contributors
|
||||
#include <xgboost/linear_updater.h>
|
||||
#include "../helpers.h"
|
||||
#include "xgboost/gbm.h"
|
||||
|
||||
typedef std::pair<std::string, std::string> arg;
|
||||
|
||||
TEST(Linear, shotgun) {
|
||||
typedef std::pair<std::string, std::string> arg;
|
||||
auto mat = CreateDMatrix(10, 10, 0);
|
||||
std::vector<bool> enabled(mat->info().num_col, true);
|
||||
mat->InitColAccess(enabled, 1.0f, 1 << 16, false);
|
||||
auto updater = std::unique_ptr<xgboost::LinearUpdater>(
|
||||
xgboost::LinearUpdater::Create("shotgun"));
|
||||
updater->Init({});
|
||||
std::vector<xgboost::bst_gpair> gpair(mat->info().num_row,
|
||||
xgboost::bst_gpair(-5, 1.0));
|
||||
xgboost::gbm::GBLinearModel model;
|
||||
model.param.num_feature = mat->info().num_col;
|
||||
model.param.num_output_group = 1;
|
||||
model.LazyInitModel();
|
||||
updater->Update(&gpair, mat.get(), &model, gpair.size());
|
||||
|
||||
ASSERT_EQ(model.bias()[0], 5.0f);
|
||||
}
|
||||
|
||||
TEST(Linear, coordinate) {
|
||||
typedef std::pair<std::string, std::string> arg;
|
||||
auto mat = CreateDMatrix(10, 10, 0);
|
||||
std::vector<bool> enabled(mat->info().num_col, true);
|
||||
mat->InitColAccess(enabled, 1.0f, 1 << 16, false);
|
||||
auto updater = std::unique_ptr<xgboost::LinearUpdater>(
|
||||
xgboost::LinearUpdater::Create("coord_descent"));
|
||||
updater->Init({});
|
||||
std::vector<xgboost::bst_gpair> gpair(mat->info().num_row,
|
||||
xgboost::bst_gpair(-5, 1.0));
|
||||
xgboost::gbm::GBLinearModel model;
|
||||
model.param.num_feature = mat->info().num_col;
|
||||
model.param.num_output_group = 1;
|
||||
model.LazyInitModel();
|
||||
updater->Update(&gpair, mat.get(), &model, gpair.size());
|
||||
|
||||
ASSERT_EQ(model.bias()[0], 5.0f);
|
||||
}
|
||||
Reference in New Issue
Block a user