xgboost/plugin/example/custom_obj.cc
Andy Adinets 72cd1517d6 Replaced std::vector with HostDeviceVector in MetaInfo and SparsePage. (#3446)
* Replaced std::vector with HostDeviceVector in MetaInfo and SparsePage.

- added distributions to HostDeviceVector
- using HostDeviceVector for labels, weights and base margins in MetaInfo
- using HostDeviceVector for offset and data in SparsePage
- other necessary refactoring

* Added const version of HostDeviceVector API calls.

- const versions added to calls that can trigger data transfers, e.g. DevicePointer()
- updated the code that uses HostDeviceVector
- objective functions now accept const HostDeviceVector<bst_float>& for predictions

* Updated src/linear/updater_gpu_coordinate.cu.

* Added read-only state for HostDeviceVector sync.

- this means no copies are performed if both host and devices access
  the HostDeviceVector read-only

* Fixed linter and test errors.

- updated the lz4 plugin
- added ConstDeviceSpan to HostDeviceVector
- using device % dh::NVisibleDevices() for the physical device number,
  e.g. in calls to cudaSetDevice()

* Fixed explicit template instantiation errors for HostDeviceVector.

- replaced HostDeviceVector<unsigned int> with HostDeviceVector<int>

* Fixed HostDeviceVector tests that require multiple GPUs.

- added a mock set device handler; when set, it is called instead of cudaSetDevice()
2018-08-30 14:28:47 +12:00

84 lines
2.9 KiB
C++

/*!
* Copyright 2015 by Contributors
* \file custom_obj.cc
* \brief An example of how to define an xgboost plugin.
* This plugin defines an additional (custom) objective function.
*/
#include <cmath>
#include <string>
#include <utility>
#include <vector>

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <xgboost/base.h>
#include <xgboost/objective.h>
namespace xgboost {
namespace obj {
// Parameters of the MyLogistic objective, declared with dmlc::Parameter.
// Using dmlc::Parameter is optional -- an objective may read its arguments
// any way it likes -- but it gives defaults, range checks and help text
// for free. See http://dmlc-core.readthedocs.org/en/latest/parameter.html
// for an introduction to this module.
struct MyLogisticParam : public dmlc::Parameter<MyLogisticParam> {
  /*! \brief factor by which the weight of each negative example is scaled */
  float scale_neg_weight;
  // Field registration: default value, lower bound and description.
  DMLC_DECLARE_PARAMETER(MyLogisticParam) {
    DMLC_DECLARE_FIELD(scale_neg_weight)
        .set_default(1.0f)
        .set_lower_bound(0.0f)
        .describe("Scale the weight of negative examples by this factor");
  }
};
DMLC_REGISTER_PARAMETER(MyLogisticParam);
// Define a customized logistic regression objective in C++.
// Implement the interface.
class MyLogistic : public ObjFunction {
public:
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.InitAllowUnknown(args);
}
void GetGradient(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
int iter,
HostDeviceVector<GradientPair> *out_gpair) override {
out_gpair->Resize(preds.Size());
const std::vector<bst_float>& preds_h = preds.HostVector();
std::vector<GradientPair>& out_gpair_h = out_gpair->HostVector();
const std::vector<bst_float>& labels_h = info.labels_.HostVector();
for (size_t i = 0; i < preds_h.size(); ++i) {
bst_float w = info.GetWeight(i);
// scale the negative examples!
if (labels_h[i] == 0.0f) w *= param_.scale_neg_weight;
// logistic transformation
bst_float p = 1.0f / (1.0f + std::exp(-preds_h[i]));
// this is the gradient
bst_float grad = (p - labels_h[i]) * w;
// this is the second order gradient
bst_float hess = p * (1.0f - p) * w;
out_gpair_h.at(i) = GradientPair(grad, hess);
}
}
const char* DefaultEvalMetric() const override {
return "error";
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
// transform margin value to probability.
std::vector<bst_float> &preds = io_preds->HostVector();
for (size_t i = 0; i < preds.size(); ++i) {
preds[i] = 1.0f / (1.0f + std::exp(-preds[i]));
}
}
bst_float ProbToMargin(bst_float base_score) const override {
// transform probability to margin value
return -std::log(1.0f / base_score - 1.0f);
}
private:
MyLogisticParam param_;
};
// Register MyLogistic under the name "mylogistic". Once registration has
// succeeded, xgboost can be trained with objective=mylogistic.
XGBOOST_REGISTER_OBJECTIVE(MyLogistic, "mylogistic")
    .describe("User defined logistic regression plugin")
    .set_body([] { return new MyLogistic(); });
} // namespace obj
} // namespace xgboost