[LIBXGBOOST] pass demo running.
This commit is contained in:
@@ -4,9 +4,9 @@
|
||||
* \brief Definition of multi-class classification objectives.
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <dmlc/logging.h>
|
||||
#include <dmlc/omp.h>
|
||||
#include <dmlc/parameter.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/objective.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
@@ -16,6 +16,8 @@
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(multiclass_obj);
|
||||
|
||||
struct SoftmaxMultiClassParam : public dmlc::Parameter<SoftmaxMultiClassParam> {
|
||||
int num_class;
|
||||
// declare parameters
|
||||
|
||||
34
src/objective/objective.cc
Normal file
34
src/objective/objective.cc
Normal file
@@ -0,0 +1,34 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file objective.cc
|
||||
* \brief Registry of all objective functions.
|
||||
*/
|
||||
#include <xgboost/objective.h>
|
||||
#include <dmlc/registry.h>
|
||||
|
||||
namespace dmlc {
|
||||
DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg);
|
||||
} // namespace dmlc
|
||||
|
||||
namespace xgboost {
|
||||
// implement factory functions
|
||||
ObjFunction* ObjFunction::Create(const std::string& name) {
|
||||
auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(name);
|
||||
if (e == nullptr) {
|
||||
for (const auto& entry : ::dmlc::Registry< ::xgboost::ObjFunctionReg>::List()) {
|
||||
LOG(INFO) << "Objective candidate: " << entry->name;
|
||||
}
|
||||
LOG(FATAL) << "Unknown objective function " << name;
|
||||
}
|
||||
return (e->body)();
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
// List of files that will be force linked in static links.
|
||||
DMLC_REGISTRY_LINK_TAG(regression_obj);
|
||||
DMLC_REGISTRY_LINK_TAG(multiclass_obj);
|
||||
DMLC_REGISTRY_LINK_TAG(rank_obj);
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
@@ -4,8 +4,8 @@
|
||||
* \brief Definition of rank loss.
|
||||
* \author Tianqi Chen, Kailong Chen
|
||||
*/
|
||||
#include <dmlc/logging.h>
|
||||
#include <dmlc/omp.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/objective.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
@@ -16,6 +16,8 @@
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(rank_obj);
|
||||
|
||||
struct LambdaRankParam : public dmlc::Parameter<LambdaRankParam> {
|
||||
int num_pairsample;
|
||||
float fix_list_weight;
|
||||
@@ -324,4 +326,3 @@ XGBOOST_REGISTER_OBJECTIVE(LambdaRankObjMAP, "rank:map")
|
||||
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
* \brief Definition of single-value regression and classification objectives.
|
||||
* \author Tianqi Chen, Kailong Chen
|
||||
*/
|
||||
#include <dmlc/logging.h>
|
||||
#include <dmlc/omp.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/objective.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
@@ -14,6 +14,9 @@
|
||||
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(regression_obj);
|
||||
|
||||
// common regressions
|
||||
// linear regression
|
||||
struct LinearSquareLoss {
|
||||
@@ -84,7 +87,9 @@ class RegLossObj : public ObjFunction {
|
||||
int iter,
|
||||
std::vector<bst_gpair> *out_gpair) override {
|
||||
CHECK_NE(info.labels.size(), 0) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.size(), info.labels.size()) << "labels are not correctly provided";
|
||||
CHECK_EQ(preds.size(), info.labels.size())
|
||||
<< "labels are not correctly provided"
|
||||
<< "preds.size=" << preds.size() << ", label.size=" << info.labels.size();
|
||||
out_gpair->resize(preds.size());
|
||||
// check if label in range
|
||||
bool label_correct = true;
|
||||
@@ -95,7 +100,7 @@ class RegLossObj : public ObjFunction {
|
||||
float p = Loss::PredTransform(preds[i]);
|
||||
float w = info.GetWeight(i);
|
||||
if (info.labels[i] == 1.0f) w *= param_.scale_pos_weight;
|
||||
if (Loss::CheckLabel(info.labels[i])) label_correct = false;
|
||||
if (!Loss::CheckLabel(info.labels[i])) label_correct = false;
|
||||
out_gpair->at(i) = bst_gpair(Loss::FirstOrderGradient(p, info.labels[i]) * w,
|
||||
Loss::SecondOrderGradient(p, info.labels[i]) * w);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user