[LIBXGBOOST] pass demo running.
This commit is contained in:
@@ -5,10 +5,10 @@
|
||||
* the update rule is parallel coordinate descent (shotgun)
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <dmlc/logging.h>
|
||||
#include <dmlc/omp.h>
|
||||
#include <dmlc/parameter.h>
|
||||
#include <xgboost/gbm.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
@@ -17,6 +17,9 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace gbm {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(gblinear);
|
||||
|
||||
// model parameter
|
||||
struct GBLinearModelParam :public dmlc::Parameter<GBLinearModelParam> {
|
||||
// number of feature dimension
|
||||
@@ -168,6 +171,9 @@ class GBLinear : public GradientBooster {
|
||||
int64_t buffer_offset,
|
||||
std::vector<float> *out_preds,
|
||||
unsigned ntree_limit) override {
|
||||
if (model.weight.size() == 0) {
|
||||
model.InitModel();
|
||||
}
|
||||
CHECK_EQ(ntree_limit, 0)
|
||||
<< "GBLinear::Predict ntrees is only valid for gbtree predictor";
|
||||
std::vector<float> &preds = *out_preds;
|
||||
@@ -293,4 +299,3 @@ XGBOOST_REGISTER_GBM(GBLinear, "gblinear")
|
||||
});
|
||||
} // namespace gbm
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
29
src/gbm/gbm.cc
Normal file
29
src/gbm/gbm.cc
Normal file
@@ -0,0 +1,29 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file gbm.cc
|
||||
* \brief Registry of gradient boosters.
|
||||
*/
|
||||
#include <xgboost/gbm.h>
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::GradientBoosterReg);
|
||||
} // namespace dmlc
|
||||
|
||||
namespace xgboost {
|
||||
GradientBooster* GradientBooster::Create(const std::string& name) {
|
||||
auto *e = ::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->Find(name);
|
||||
if (e == nullptr) {
|
||||
LOG(FATAL) << "Unknown gbm type " << name;
|
||||
}
|
||||
return (e->body)();
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
namespace xgboost {
|
||||
namespace gbm {
|
||||
// List of files that will be force linked in static links.
|
||||
DMLC_REGISTRY_LINK_TAG(gblinear);
|
||||
DMLC_REGISTRY_LINK_TAG(gbtree);
|
||||
} // namespace gbm
|
||||
} // namespace xgboost
|
||||
@@ -4,9 +4,9 @@
|
||||
* \brief gradient boosted tree implementation.
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <dmlc/logging.h>
|
||||
#include <dmlc/omp.h>
|
||||
#include <dmlc/parameter.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/gbm.h>
|
||||
#include <xgboost/tree_updater.h>
|
||||
|
||||
@@ -19,6 +19,8 @@
|
||||
namespace xgboost {
|
||||
namespace gbm {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(gbtree);
|
||||
|
||||
/*! \brief training parameters */
|
||||
struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
|
||||
/*! \brief number of threads */
|
||||
@@ -482,4 +484,3 @@ XGBOOST_REGISTER_GBM(GBTree, "gbtree")
|
||||
});
|
||||
} // namespace gbm
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
Reference in New Issue
Block a user