Refactor linear modelling and add new coordinate descent updater (#3103)
* Refactor linear modelling and add new coordinate descent updater * Allow unsorted column iterator * Add prediction caching to gblinear
This commit is contained in:
@@ -274,14 +274,16 @@ class DMatrix {
|
||||
* \param subsample subsample ratio when generating column access.
|
||||
* \param max_row_perbatch auxiliary information, maximum row used in each column batch.
|
||||
* this is a hint information that can be ignored by the implementation.
|
||||
* \param sorted If column features should be in sorted order
|
||||
* \return Number of column blocks in the column access.
|
||||
*/
|
||||
|
||||
virtual void InitColAccess(const std::vector<bool>& enabled,
|
||||
float subsample,
|
||||
size_t max_row_perbatch) = 0;
|
||||
size_t max_row_perbatch, bool sorted) = 0;
|
||||
// the following are column meta data, should be able to answer them fast.
|
||||
/*! \return whether column access is enabled */
|
||||
virtual bool HaveColAccess() const = 0;
|
||||
virtual bool HaveColAccess(bool sorted) const = 0;
|
||||
/*! \return Whether the data columns form a single column block. */
|
||||
virtual bool SingleColBlock() const = 0;
|
||||
/*! \brief get number of non-missing entries in column */
|
||||
|
||||
66
include/xgboost/linear_updater.h
Normal file
66
include/xgboost/linear_updater.h
Normal file
@@ -0,0 +1,66 @@
|
||||
/*
|
||||
* Copyright 2018 by Contributors
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
#include <xgboost/base.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "../../src/gbm/gblinear_model.h"
|
||||
|
||||
namespace xgboost {
|
||||
/*!
|
||||
* \brief interface of linear updater
|
||||
*/
|
||||
class LinearUpdater {
|
||||
public:
|
||||
/*! \brief virtual destructor */
|
||||
virtual ~LinearUpdater() {}
|
||||
/*!
|
||||
* \brief Initialize the updater with given arguments.
|
||||
* \param args arguments to the objective function.
|
||||
*/
|
||||
virtual void Init(
|
||||
const std::vector<std::pair<std::string, std::string> >& args) = 0;
|
||||
|
||||
/**
|
||||
* \brief Updates linear model given gradients.
|
||||
*
|
||||
* \param in_gpair The gradient pair statistics of the data.
|
||||
* \param data Input data matrix.
|
||||
* \param model Model to be updated.
|
||||
* \param sum_instance_weight The sum instance weights, used to normalise l1/l2 penalty.
|
||||
*/
|
||||
|
||||
virtual void Update(std::vector<bst_gpair>* in_gpair, DMatrix* data,
|
||||
gbm::GBLinearModel* model,
|
||||
double sum_instance_weight) = 0;
|
||||
|
||||
/*!
|
||||
* \brief Create a linear updater given name
|
||||
* \param name Name of the linear updater.
|
||||
*/
|
||||
static LinearUpdater* Create(const std::string& name);
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Registry entry for linear updater.
|
||||
*/
|
||||
struct LinearUpdaterReg
|
||||
: public dmlc::FunctionRegEntryBase<LinearUpdaterReg,
|
||||
std::function<LinearUpdater*()> > {};
|
||||
|
||||
/*!
|
||||
* \brief Macro to register linear updater.
|
||||
*/
|
||||
// Expands to the declaration of a static reference, named after UniqueId,
// bound to the LinearUpdaterReg entry that the dmlc registry returns from
// __REGISTER__(Name). The expansion carries no trailing semicolon, so further
// member calls can be chained onto the entry at the use site.
//   UniqueId - token pasted into the variable name to keep it unique.
//   Name     - name under which the updater is registered.
#define XGBOOST_REGISTER_LINEAR_UPDATER(UniqueId, Name)                     \
  static DMLC_ATTRIBUTE_UNUSED ::xgboost::LinearUpdaterReg&                 \
      __make_LinearUpdaterReg_##UniqueId##__ =                              \
          ::dmlc::Registry< ::xgboost::LinearUpdaterReg>::Get()             \
              ->__REGISTER__(Name)
|
||||
|
||||
} // namespace xgboost
|
||||
Reference in New Issue
Block a user