Implement transform to reduce CPU/GPU code duplication. (#3643)

* Implement Transform class.
* Add tests for softmax.
* Use Transform in regression, softmax and hinge objectives, except for Cox.
* Mark old gpu objective functions deprecated.
* static_assert for softmax.
* Split up multi-gpu tests.
This commit is contained in:
trivialfis 2018-10-02 15:06:21 +13:00 committed by Rory Mitchell
parent 87aca8c244
commit d594b11f35
31 changed files with 1514 additions and 997 deletions

View File

@ -1,9 +1,11 @@
/*!
* Copyright 2015 by Contributors
* Copyright 2015-2018 by Contributors
* \file common.cc
* \brief Enable all kinds of global variables in common.
*/
#include <dmlc/thread_local.h>
#include "common.h"
#include "./random.h"
namespace xgboost {

View File

@ -11,7 +11,7 @@ int AllVisibleImpl::AllVisible() {
// When compiled with CUDA but running on CPU only device,
// cudaGetDeviceCount will fail.
dh::safe_cuda(cudaGetDeviceCount(&n_visgpus));
} catch(const std::exception& e) {
} catch(const thrust::system::system_error& err) {
return 0;
}
return n_visgpus;

View File

@ -1,5 +1,5 @@
/*!
* Copyright 2015 by Contributors
* Copyright 2015-2018 by Contributors
* \file common.h
* \brief Common utilities
*/
@ -19,6 +19,13 @@
#if defined(__CUDACC__)
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#define WITH_CUDA() true
#else
#define WITH_CUDA() false
#endif
namespace dh {
@ -29,11 +36,11 @@ namespace dh {
#define safe_cuda(ans) ThrowOnCudaError((ans), __FILE__, __LINE__)
inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file,
int line) {
int line) {
if (code != cudaSuccess) {
throw thrust::system_error(code, thrust::cuda_category(),
std::string{file} + "(" + // NOLINT
std::to_string(line) + ")");
LOG(FATAL) << thrust::system_error(code, thrust::cuda_category(),
std::string{file} + ": " + // NOLINT
std::to_string(line)).what();
}
return code;
}
@ -70,13 +77,13 @@ inline std::string ToString(const T& data) {
*/
class Range {
public:
using DifferenceType = int64_t;
class Iterator {
friend class Range;
public:
using DifferenceType = int64_t;
XGBOOST_DEVICE int64_t operator*() const { return i_; }
XGBOOST_DEVICE DifferenceType operator*() const { return i_; }
XGBOOST_DEVICE const Iterator &operator++() {
i_ += step_;
return *this;
@ -97,8 +104,8 @@ class Range {
XGBOOST_DEVICE void Step(DifferenceType s) { step_ = s; }
protected:
XGBOOST_DEVICE explicit Iterator(int64_t start) : i_(start) {}
XGBOOST_DEVICE explicit Iterator(int64_t start, int step) :
XGBOOST_DEVICE explicit Iterator(DifferenceType start) : i_(start) {}
XGBOOST_DEVICE explicit Iterator(DifferenceType start, DifferenceType step) :
i_{start}, step_{step} {}
public:
@ -109,9 +116,10 @@ class Range {
XGBOOST_DEVICE Iterator begin() const { return begin_; } // NOLINT
XGBOOST_DEVICE Iterator end() const { return end_; } // NOLINT
XGBOOST_DEVICE Range(int64_t begin, int64_t end)
XGBOOST_DEVICE Range(DifferenceType begin, DifferenceType end)
: begin_(begin), end_(end) {}
XGBOOST_DEVICE Range(int64_t begin, int64_t end, Iterator::DifferenceType step)
XGBOOST_DEVICE Range(DifferenceType begin, DifferenceType end,
DifferenceType step)
: begin_(begin, step), end_(end) {}
XGBOOST_DEVICE bool operator==(const Range& other) const {
@ -121,9 +129,7 @@ class Range {
return !(*this == other);
}
XGBOOST_DEVICE void Step(Iterator::DifferenceType s) { begin_.Step(s); }
XGBOOST_DEVICE Iterator::DifferenceType GetStep() const { return begin_.step_; }
XGBOOST_DEVICE void Step(DifferenceType s) { begin_.Step(s); }
private:
Iterator begin_;

View File

@ -9,6 +9,7 @@
#include <xgboost/logging.h>
#include "common.h"
#include "span.h"
#include <algorithm>
#include <chrono>
@ -955,7 +956,7 @@ class SaveCudaContext {
// cudaGetDevice will fail.
try {
safe_cuda(cudaGetDevice(&saved_device_));
} catch (thrust::system::system_error & err) {
} catch (const thrust::system::system_error & err) {
saved_device_ = -1;
}
func();
@ -1035,4 +1036,22 @@ ReduceT ReduceShards(std::vector<ShardT> *shards, FunctionT f) {
};
return std::accumulate(sums.begin(), sums.end(), ReduceT());
}
template <typename T,
typename IndexT = typename xgboost::common::Span<T>::index_type>
xgboost::common::Span<T> ToSpan(
thrust::device_vector<T>& vec,
IndexT offset = 0,
IndexT size = -1) {
size = size == -1 ? vec.size() : size;
CHECK_LE(offset + size, vec.size());
return {vec.data().get() + offset, static_cast<IndexT>(size)};
}
template <typename T>
xgboost::common::Span<T> ToSpan(thrust::device_vector<T>& vec,
size_t offset, size_t size) {
using IndexT = typename xgboost::common::Span<T>::index_type;
return ToSpan(vec, static_cast<IndexT>(offset), static_cast<IndexT>(size));
}
} // namespace dh

View File

@ -116,6 +116,7 @@ struct HostDeviceVectorImpl {
int ndevices = vec_->distribution_.devices_.Size();
start_ = vec_->distribution_.ShardStart(new_size, index_);
proper_size_ = vec_->distribution_.ShardProperSize(new_size, index_);
// The size on this device.
size_t size_d = vec_->distribution_.ShardSize(new_size, index_);
SetDevice();
data_.resize(size_d);
@ -230,7 +231,7 @@ struct HostDeviceVectorImpl {
CHECK(devices.Contains(device));
LazySyncDevice(device, GPUAccess::kWrite);
return {shards_[devices.Index(device)].data_.data().get(),
static_cast<typename common::Span<T>::index_type>(DeviceSize(device))};
static_cast<typename common::Span<T>::index_type>(DeviceSize(device))};
}
common::Span<const T> ConstDeviceSpan(int device) {
@ -238,7 +239,7 @@ struct HostDeviceVectorImpl {
CHECK(devices.Contains(device));
LazySyncDevice(device, GPUAccess::kRead);
return {shards_[devices.Index(device)].data_.data().get(),
static_cast<typename common::Span<const T>::index_type>(DeviceSize(device))};
static_cast<typename common::Span<const T>::index_type>(DeviceSize(device))};
}
size_t DeviceSize(int device) {
@ -289,7 +290,6 @@ struct HostDeviceVectorImpl {
data_h_.size() * sizeof(T),
cudaMemcpyHostToDevice));
} else {
//
dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { shard.GatherTo(begin); });
}
}
@ -304,14 +304,20 @@ struct HostDeviceVectorImpl {
void Copy(HostDeviceVectorImpl<T>* other) {
CHECK_EQ(Size(), other->Size());
// Data is on host.
if (perm_h_.CanWrite() && other->perm_h_.CanWrite()) {
std::copy(other->data_h_.begin(), other->data_h_.end(), data_h_.begin());
} else {
CHECK(distribution_ == other->distribution_);
dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) {
shard.Copy(&other->shards_[i]);
});
return;
}
// Data is on device;
if (distribution_ != other->distribution_) {
distribution_ = GPUDistribution();
Reshard(other->Distribution());
size_d_ = other->size_d_;
}
dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) {
shard.Copy(&other->shards_[i]);
});
}
void Copy(const std::vector<T>& other) {

View File

@ -111,8 +111,11 @@ class GPUDistribution {
}
friend bool operator==(const GPUDistribution& a, const GPUDistribution& b) {
return a.devices_ == b.devices_ && a.granularity_ == b.granularity_ &&
a.overlap_ == b.overlap_ && a.offsets_ == b.offsets_;
bool const res = a.devices_ == b.devices_ &&
a.granularity_ == b.granularity_ &&
a.overlap_ == b.overlap_ &&
a.offsets_ == b.offsets_;
return res;
}
friend bool operator!=(const GPUDistribution& a, const GPUDistribution& b) {

View File

@ -11,6 +11,7 @@
#include <vector>
#include <cmath>
#include <algorithm>
#include <utility>
#include "avx_helpers.h"
namespace xgboost {
@ -29,22 +30,31 @@ inline avx::Float8 Sigmoid(avx::Float8 x) {
}
/*!
* \brief do inplace softmax transformaton on p_rec
* \param p_rec the input/output vector of the values.
* \brief Do inplace softmax transformaton on start to end
*
* \tparam Iterator Input iterator type
*
* \param start Start iterator of input
* \param end end iterator of input
*/
inline void Softmax(std::vector<float>* p_rec) {
std::vector<float> &rec = *p_rec;
float wmax = rec[0];
for (size_t i = 1; i < rec.size(); ++i) {
wmax = std::max(rec[i], wmax);
template <typename Iterator>
XGBOOST_DEVICE inline void Softmax(Iterator start, Iterator end) {
static_assert(std::is_same<bst_float,
typename std::remove_reference<
decltype(std::declval<Iterator>().operator*())>::type
>::value,
"Values should be of type bst_float");
bst_float wmax = *start;
for (Iterator i = start+1; i != end; ++i) {
wmax = fmaxf(*i, wmax);
}
double wsum = 0.0f;
for (float & elem : rec) {
elem = std::exp(elem - wmax);
wsum += elem;
for (Iterator i = start; i != end; ++i) {
*i = expf(*i - wmax);
wsum += *i;
}
for (float & elem : rec) {
elem /= static_cast<float>(wsum);
for (Iterator i = start; i != end; ++i) {
*i /= static_cast<float>(wsum);
}
}
@ -56,7 +66,7 @@ inline void Softmax(std::vector<float>* p_rec) {
* \tparam Iterator The type of the iterator.
*/
template<typename Iterator>
inline Iterator FindMaxIndex(Iterator begin, Iterator end) {
XGBOOST_DEVICE inline Iterator FindMaxIndex(Iterator begin, Iterator end) {
Iterator maxit = begin;
for (Iterator it = begin; it != end; ++it) {
if (*it > *maxit) maxit = it;

View File

@ -49,7 +49,7 @@
*
* https://github.com/Microsoft/GSL/pull/664
*
* FIXME: Group these MSVC workarounds into a manageable place.
* TODO(trivialfis): Group these MSVC workarounds into a manageable place.
*/
#if defined(_MSC_VER) && _MSC_VER < 1910
@ -68,7 +68,7 @@ namespace xgboost {
namespace common {
// Usual logging facility is not available inside device code.
// FIXME: Make dmlc check more generic.
// TODO(trivialfis): Make dmlc check more generic.
#define KERNEL_CHECK(cond) \
do { \
if (!(cond)) { \
@ -104,11 +104,11 @@ constexpr detail::ptrdiff_t dynamic_extent = -1; // NOLINT
enum class byte : unsigned char {}; // NOLINT
namespace detail {
template <class ElementType, detail::ptrdiff_t Extent = dynamic_extent>
template <class ElementType, detail::ptrdiff_t Extent>
class Span;
namespace detail {
template <typename SpanType, bool IsConst>
class SpanIterator {
using ElementType = typename SpanType::element_type;

204
src/common/transform.h Normal file
View File

@ -0,0 +1,204 @@
/*!
* Copyright 2018 XGBoost contributors
*/
#ifndef XGBOOST_COMMON_TRANSFORM_H_
#define XGBOOST_COMMON_TRANSFORM_H_
#include <dmlc/omp.h>
#include <xgboost/data.h>
#include <vector>
#include <type_traits> // enable_if
#include "host_device_vector.h"
#include "common.h"
#include "span.h"
#if defined (__CUDACC__)
#include "device_helpers.cuh"
#endif
namespace xgboost {
namespace common {
constexpr size_t kBlockThreads = 256;
namespace detail {
#if defined(__CUDACC__)
template <typename Functor, typename... SpanType>
__global__ void LaunchCUDAKernel(Functor _func, Range _range,
SpanType... _spans) {
for (auto i : dh::GridStrideRange(*_range.begin(), *_range.end())) {
_func(i, _spans...);
}
}
#endif
} // namespace detail
/*! \brief Do Transformation on HostDeviceVectors.
*
* \tparam CompiledWithCuda A bool parameter used to distinguish compilation
* trajectories, users do not need to use it.
*
* Note: Using Transform is a VERY tricky thing to do. Transform uses template
* argument to duplicate itself into two different types, one for CPU,
* another for CUDA. The trick is not without its flaw:
*
* If you use it in a function that can be compiled by both nvcc and host
* compiler, the behaviour is un-defined! Because your function is NOT
* duplicated by `CompiledWithCuda`. At link time, cuda compiler resolution
* will merge functions with same signature.
*/
template <bool CompiledWithCuda = WITH_CUDA()>
class Transform {
private:
template <typename Functor>
struct Evaluator {
public:
Evaluator(Functor func, Range range, GPUSet devices, bool reshard) :
func_(func), range_{std::move(range)},
distribution_{std::move(GPUDistribution::Block(devices))},
reshard_{reshard} {}
Evaluator(Functor func, Range range, GPUDistribution dist,
bool reshard) :
func_(func), range_{std::move(range)}, distribution_{std::move(dist)},
reshard_{reshard} {}
/*!
* \brief Evaluate the functor with input pointers to HostDeviceVector.
*
* \tparam HDV... HostDeviceVectors type.
* \param vectors Pointers to HostDeviceVector.
*/
template <typename... HDV>
void Eval(HDV... vectors) const {
bool on_device = !distribution_.IsEmpty();
if (on_device) {
LaunchCUDA(func_, vectors...);
} else {
LaunchCPU(func_, vectors...);
}
}
private:
// CUDA UnpackHDV
template <typename T>
Span<T> UnpackHDV(HostDeviceVector<T>* _vec, int _device) const {
return _vec->DeviceSpan(_device);
}
template <typename T>
Span<T const> UnpackHDV(const HostDeviceVector<T>* _vec, int _device) const {
return _vec->ConstDeviceSpan(_device);
}
// CPU UnpackHDV
template <typename T>
Span<T> UnpackHDV(HostDeviceVector<T>* _vec) const {
return Span<T> {_vec->HostPointer(),
static_cast<typename Span<T>::index_type>(_vec->Size())};
}
template <typename T>
Span<T const> UnpackHDV(const HostDeviceVector<T>* _vec) const {
return Span<T const> {_vec->ConstHostPointer(),
static_cast<typename Span<T>::index_type>(_vec->Size())};
}
// Recursive unpack for Reshard.
template <typename T>
void UnpackReshard(GPUDistribution dist, const HostDeviceVector<T>* vector) const {
vector->Reshard(dist);
}
template <typename Head, typename... Rest>
void UnpackReshard(GPUDistribution dist,
const HostDeviceVector<Head>* _vector,
const HostDeviceVector<Rest>*... _vectors) const {
_vector->Reshard(dist);
UnpackReshard(dist, _vectors...);
}
#if defined(__CUDACC__)
template <typename std::enable_if<CompiledWithCuda>::type* = nullptr,
typename... HDV>
void LaunchCUDA(Functor _func, HDV*... _vectors) const {
if (reshard_)
UnpackReshard(distribution_, _vectors...);
GPUSet devices = distribution_.Devices();
size_t range_size = *range_.end() - *range_.begin();
#pragma omp parallel for schedule(static, 1) if (devices.Size() > 1)
for (omp_ulong i = 0; i < devices.Size(); ++i) {
int d = devices.Index(i);
// Ignore other attributes of GPUDistribution for spliting index.
size_t shard_size =
GPUDistribution::Block(devices).ShardSize(range_size, d);
Range shard_range {0, static_cast<Range::DifferenceType>(shard_size)};
dh::safe_cuda(cudaSetDevice(d));
const int GRID_SIZE =
static_cast<int>(dh::DivRoundUp(*(range_.end()), kBlockThreads));
detail::LaunchCUDAKernel<<<GRID_SIZE, kBlockThreads>>>(
_func, shard_range, UnpackHDV(_vectors, d)...);
dh::safe_cuda(cudaGetLastError());
dh::safe_cuda(cudaDeviceSynchronize());
}
}
#else
/*! \brief Dummy funtion defined when compiling for CPU. */
template <typename std::enable_if<!CompiledWithCuda>::type* = nullptr,
typename... HDV>
void LaunchCUDA(Functor _func, HDV*... _vectors) const {
LOG(FATAL) << "Not part of device code. WITH_CUDA: " << WITH_CUDA();
}
#endif
template <typename... HDV>
void LaunchCPU(Functor func, HDV*... vectors) const {
auto end = *(range_.end());
#pragma omp parallel for schedule(static)
for (omp_ulong idx = 0; idx < end; ++idx) {
func(idx, UnpackHDV(vectors)...);
}
}
private:
/*! \brief Callable object. */
Functor func_;
/*! \brief Range object specifying parallel threads index range. */
Range range_;
/*! \brief Whether resharding for vectors is required. */
bool reshard_;
GPUDistribution distribution_;
};
public:
/*!
* \brief Initialize a Transform object.
*
* \tparam Functor A callable object type.
* \return A Evaluator having one method Eval.
*
* \param func A callable object, accepting a size_t thread index,
* followed by a set of Span classes.
* \param range Range object specifying parallel threads index range.
* \param devices GPUSet specifying GPUs to use, when compiling for CPU,
* this should be GPUSet::Empty().
* \param reshard Whether Reshard for HostDeviceVector is needed.
*/
template <typename Functor>
static Evaluator<Functor> Init(Functor func, Range const range,
GPUSet const devices,
bool const reshard = true) {
return Evaluator<Functor> {func, std::move(range), std::move(devices), reshard};
}
template <typename Functor>
static Evaluator<Functor> Init(Functor func, Range const range,
GPUDistribution const dist,
bool const reshard = true) {
return Evaluator<Functor> {func, std::move(range), std::move(dist), reshard};
}
};
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_TRANSFORM_H_

View File

@ -1,73 +1,18 @@
/*!
* Copyright 2018 by Contributors
* \file hinge.cc
* \brief Provides an implementation of the hinge loss function
* \author Henry Gouk
* Copyright 2018 XGBoost contributors
*/
#include <xgboost/objective.h>
#include "../common/math.h"
// Dummy file to keep the CUDA conditional compile trick.
#include <dmlc/registry.h>
namespace xgboost {
namespace obj {
DMLC_REGISTRY_FILE_TAG(hinge);
class HingeObj : public ObjFunction {
public:
HingeObj() = default;
void Configure(
const std::vector<std::pair<std::string, std::string> > &args) override {
// This objective does not take any parameters
}
void GetGradient(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
int iter,
HostDeviceVector<GradientPair> *out_gpair) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels_.Size())
<< "labels are not correctly provided"
<< "preds.size=" << preds.Size()
<< ", label.size=" << info.labels_.Size();
const auto& preds_h = preds.HostVector();
const auto& labels_h = info.labels_.HostVector();
const auto& weights_h = info.weights_.HostVector();
out_gpair->Resize(preds_h.size());
auto& gpair = out_gpair->HostVector();
for (size_t i = 0; i < preds_h.size(); ++i) {
auto y = labels_h[i] * 2.0 - 1.0;
bst_float p = preds_h[i];
bst_float w = weights_h.size() > 0 ? weights_h[i] : 1.0f;
bst_float g, h;
if (p * y < 1.0) {
g = -y * w;
h = w;
} else {
g = 0.0;
h = std::numeric_limits<bst_float>::min();
}
gpair[i] = GradientPair(g, h);
}
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
std::vector<bst_float> &preds = io_preds->HostVector();
for (auto& p : preds) {
p = p > 0.0 ? 1.0 : 0.0;
}
}
const char* DefaultEvalMetric() const override {
return "error";
}
};
XGBOOST_REGISTER_OBJECTIVE(HingeObj, "binary:hinge")
.describe("Hinge loss. Expects labels to be in [0,1f]")
.set_body([]() { return new HingeObj(); });
DMLC_REGISTRY_FILE_TAG(hinge_obj);
} // namespace obj
} // namespace xgboost
#ifndef XGBOOST_USE_CUDA
#include "hinge.cu"
#endif

109
src/objective/hinge.cu Normal file
View File

@ -0,0 +1,109 @@
/*!
* Copyright 2018 by Contributors
* \file hinge.cc
* \brief Provides an implementation of the hinge loss function
* \author Henry Gouk
*/
#include <xgboost/objective.h>
#include "../common/math.h"
#include "../common/transform.h"
#include "../common/common.h"
#include "../common/span.h"
#include "../common/host_device_vector.h"
namespace xgboost {
namespace obj {
#if defined(XGBOOST_USE_CUDA)
DMLC_REGISTRY_FILE_TAG(hinge_obj_gpu);
#endif
struct HingeObjParam : public dmlc::Parameter<HingeObjParam> {
int n_gpus;
int gpu_id;
DMLC_DECLARE_PARAMETER(HingeObjParam) {
DMLC_DECLARE_FIELD(n_gpus).set_default(0).set_lower_bound(0)
.describe("Number of GPUs to use for multi-gpu algorithms.");
DMLC_DECLARE_FIELD(gpu_id)
.set_lower_bound(0)
.set_default(0)
.describe("gpu to use for objective function evaluation");
}
};
class HingeObj : public ObjFunction {
public:
HingeObj() = default;
void Configure(
const std::vector<std::pair<std::string, std::string> > &args) override {
param_.InitAllowUnknown(args);
devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size());
}
void GetGradient(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
int iter,
HostDeviceVector<GradientPair> *out_gpair) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels_.Size())
<< "labels are not correctly provided"
<< "preds.size=" << preds.Size()
<< ", label.size=" << info.labels_.Size();
const bool is_null_weight = info.weights_.Size() == 0;
const size_t ndata = preds.Size();
out_gpair->Resize(ndata);
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx,
common::Span<int> _label_correct,
common::Span<GradientPair> _out_gpair,
common::Span<const bst_float> _preds,
common::Span<const bst_float> _labels,
common::Span<const bst_float> _weights) {
bst_float p = _preds[_idx];
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
bst_float y = _labels[_idx] * 2.0 - 1.0;
bst_float g, h;
if (p * y < 1.0) {
g = -y * w;
h = w;
} else {
g = 0.0;
h = std::numeric_limits<bst_float>::min();
}
_out_gpair[_idx] = GradientPair(g, h);
},
common::Range{0, static_cast<int64_t>(ndata)}, devices_).Eval(
&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
_preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0;
},
common::Range{0, static_cast<int64_t>(io_preds->Size()), 1}, devices_)
.Eval(io_preds);
}
const char* DefaultEvalMetric() const override {
return "error";
}
private:
GPUSet devices_;
HostDeviceVector<int> label_correct_;
HingeObjParam param_;
};
// register the objective functions
DMLC_REGISTER_PARAMETER(HingeObjParam);
// register the objective functions
XGBOOST_REGISTER_OBJECTIVE(HingeObj, "binary:hinge")
.describe("Hinge loss. Expects labels to be in [0,1f]")
.set_body([]() { return new HingeObj(); });
} // namespace obj
} // namespace xgboost

View File

@ -1,141 +1,18 @@
/*!
* Copyright 2015 by Contributors
* \file multi_class.cc
* \brief Definition of multi-class classification objectives.
* \author Tianqi Chen
* Copyright 2018 XGBoost contributors
*/
#include <dmlc/omp.h>
#include <dmlc/parameter.h>
#include <xgboost/logging.h>
#include <xgboost/objective.h>
#include <vector>
#include <algorithm>
#include <utility>
#include "../common/math.h"
// Dummy file to keep the CUDA conditional compile trick.
#include <dmlc/registry.h>
namespace xgboost {
namespace obj {
DMLC_REGISTRY_FILE_TAG(multiclass_obj);
struct SoftmaxMultiClassParam : public dmlc::Parameter<SoftmaxMultiClassParam> {
int num_class;
// declare parameters
DMLC_DECLARE_PARAMETER(SoftmaxMultiClassParam) {
DMLC_DECLARE_FIELD(num_class).set_lower_bound(1)
.describe("Number of output class in the multi-class classification.");
}
};
class SoftmaxMultiClassObj : public ObjFunction {
public:
explicit SoftmaxMultiClassObj(bool output_prob)
: output_prob_(output_prob) {
}
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.InitAllowUnknown(args);
}
void GetGradient(const HostDeviceVector<bst_float>& preds,
const MetaInfo& info,
int iter,
HostDeviceVector<GradientPair>* out_gpair) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK(preds.Size() == (static_cast<size_t>(param_.num_class) * info.labels_.Size()))
<< "SoftmaxMultiClassObj: label size and pred size does not match";
const std::vector<bst_float>& preds_h = preds.HostVector();
out_gpair->Resize(preds_h.size());
std::vector<GradientPair>& gpair = out_gpair->HostVector();
const int nclass = param_.num_class;
const auto ndata = static_cast<omp_ulong>(preds_h.size() / nclass);
const auto& labels = info.labels_.HostVector();
int label_error = 0;
#pragma omp parallel
{
std::vector<bst_float> rec(nclass);
#pragma omp for schedule(static)
for (omp_ulong i = 0; i < ndata; ++i) {
for (int k = 0; k < nclass; ++k) {
rec[k] = preds_h[i * nclass + k];
}
common::Softmax(&rec);
auto label = static_cast<int>(labels[i]);
if (label < 0 || label >= nclass) {
label_error = label; label = 0;
}
const bst_float wt = info.GetWeight(i);
for (int k = 0; k < nclass; ++k) {
bst_float p = rec[k];
const float eps = 1e-16f;
const bst_float h = fmax(2.0f * p * (1.0f - p) * wt, eps);
if (label == k) {
gpair[i * nclass + k] = GradientPair((p - 1.0f) * wt, h);
} else {
gpair[i * nclass + k] = GradientPair(p* wt, h);
}
}
}
}
CHECK(label_error >= 0 && label_error < nclass)
<< "SoftmaxMultiClassObj: label must be in [0, num_class),"
<< " num_class=" << nclass
<< " but found " << label_error << " in label.";
}
void PredTransform(HostDeviceVector<bst_float>* io_preds) override {
this->Transform(io_preds, output_prob_);
}
void EvalTransform(HostDeviceVector<bst_float>* io_preds) override {
this->Transform(io_preds, true);
}
const char* DefaultEvalMetric() const override {
return "merror";
}
private:
inline void Transform(HostDeviceVector<bst_float> *io_preds, bool prob) {
std::vector<bst_float> &preds = io_preds->HostVector();
std::vector<bst_float> tmp;
const int nclass = param_.num_class;
const auto ndata = static_cast<omp_ulong>(preds.size() / nclass);
if (!prob) tmp.resize(ndata);
#pragma omp parallel
{
std::vector<bst_float> rec(nclass);
#pragma omp for schedule(static)
for (omp_ulong j = 0; j < ndata; ++j) {
for (int k = 0; k < nclass; ++k) {
rec[k] = preds[j * nclass + k];
}
if (!prob) {
tmp[j] = static_cast<bst_float>(
common::FindMaxIndex(rec.begin(), rec.end()) - rec.begin());
} else {
common::Softmax(&rec);
for (int k = 0; k < nclass; ++k) {
preds[j * nclass + k] = rec[k];
}
}
}
}
if (!prob) preds = tmp;
}
// output probability
bool output_prob_;
// parameter
SoftmaxMultiClassParam param_;
};
// register the objective functions
DMLC_REGISTER_PARAMETER(SoftmaxMultiClassParam);
XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax")
.describe("Softmax for multi-class classification, output class index.")
.set_body([]() { return new SoftmaxMultiClassObj(false); });
XGBOOST_REGISTER_OBJECTIVE(SoftprobMultiClass, "multi:softprob")
.describe("Softmax for multi-class classification, output probability distribution.")
.set_body([]() { return new SoftmaxMultiClassObj(true); });
} // namespace obj
} // namespace xgboost
#ifndef XGBOOST_USE_CUDA
#include "multiclass_obj.cu"
#endif

View File

@ -0,0 +1,195 @@
/*!
* Copyright 2015-2018 by Contributors
* \file multi_class.cc
* \brief Definition of multi-class classification objectives.
* \author Tianqi Chen
*/
#include <dmlc/omp.h>
#include <dmlc/parameter.h>
#include <xgboost/data.h>
#include <xgboost/logging.h>
#include <xgboost/objective.h>
#include <vector>
#include <algorithm>
#include <limits>
#include <utility>
#include "../common/math.h"
#include "../common/transform.h"
namespace xgboost {
namespace obj {
#if defined(XGBOOST_USE_CUDA)
DMLC_REGISTRY_FILE_TAG(multiclass_obj_gpu);
#endif
struct SoftmaxMultiClassParam : public dmlc::Parameter<SoftmaxMultiClassParam> {
int num_class;
int n_gpus;
int gpu_id;
// declare parameters
DMLC_DECLARE_PARAMETER(SoftmaxMultiClassParam) {
DMLC_DECLARE_FIELD(num_class).set_lower_bound(1)
.describe("Number of output class in the multi-class classification.");
DMLC_DECLARE_FIELD(n_gpus).set_default(-1).set_lower_bound(-1)
.describe("Number of GPUs to use for multi-gpu algorithms.");
DMLC_DECLARE_FIELD(gpu_id)
.set_lower_bound(0)
.set_default(0)
.describe("gpu to use for objective function evaluation");
}
};
// TODO(trivialfis): Currently the resharding in softmax is less than ideal
// due to repeated copying data between CPU and GPUs. Maybe we just use single
// GPU?
class SoftmaxMultiClassObj : public ObjFunction {
public:
explicit SoftmaxMultiClassObj(bool output_prob)
: output_prob_(output_prob) {
}
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.InitAllowUnknown(args);
CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1
devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size());
}
void GetGradient(const HostDeviceVector<bst_float>& preds,
const MetaInfo& info,
int iter,
HostDeviceVector<GradientPair>* out_gpair) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK(preds.Size() == (static_cast<size_t>(param_.num_class) * info.labels_.Size()))
<< "SoftmaxMultiClassObj: label size and pred size does not match";
const int nclass = param_.num_class;
const auto ndata = static_cast<int64_t>(preds.Size() / nclass);
// clear out device memory;
out_gpair->Reshard(GPUSet::Empty());
preds.Reshard(GPUSet::Empty());
out_gpair->Reshard(GPUDistribution::Granular(devices_, nclass));
info.labels_.Reshard(GPUDistribution::Block(devices_));
info.weights_.Reshard(GPUDistribution::Block(devices_));
preds.Reshard(GPUDistribution::Granular(devices_, nclass));
label_correct_.Reshard(GPUDistribution::Block(devices_));
out_gpair->Resize(preds.Size());
label_correct_.Fill(1);
const bool is_null_weight = info.weights_.Size() == 0;
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t idx,
common::Span<GradientPair> gpair,
common::Span<bst_float const> labels,
common::Span<bst_float const> preds,
common::Span<bst_float const> weights,
common::Span<int> _label_correct) {
common::Span<bst_float const> point = preds.subspan(idx * nclass, nclass);
// Part of Softmax function
bst_float wmax = std::numeric_limits<bst_float>::min();
for (auto const i : point) { wmax = fmaxf(i, wmax); }
double wsum = 0.0f;
for (auto const i : point) { wsum += expf(i - wmax); }
auto label = labels[idx];
if (label < 0 || label >= nclass) {
_label_correct[0] = 0;
label = 0;
}
bst_float wt = is_null_weight ? 1.0f : weights[idx];
for (int k = 0; k < nclass; ++k) {
// Computation duplicated to avoid creating a cache.
bst_float p = expf(point[k] - wmax) / static_cast<float>(wsum);
const bst_float h = fmax(2.0f * p * (1.0f - p) * wt, kRtEps);
p = label == k ? p - 1.0f : p;
gpair[idx * nclass + k] = GradientPair(p * wt, h);
}
}, common::Range{0, ndata}, devices_, false)
.Eval(out_gpair, &info.labels_, &preds, &info.weights_, &label_correct_);
out_gpair->Reshard(GPUSet::Empty());
out_gpair->Reshard(GPUDistribution::Block(devices_));
preds.Reshard(GPUSet::Empty());
preds.Reshard(GPUDistribution::Block(devices_));
std::vector<int>& label_correct_h = label_correct_.HostVector();
for (auto const flag : label_correct_h) {
if (flag != 1) {
LOG(FATAL) << "SoftmaxMultiClassObj: label must be in [0, num_class).";
}
}
}
void PredTransform(HostDeviceVector<bst_float>* io_preds) override {
this->Transform(io_preds, output_prob_);
}
void EvalTransform(HostDeviceVector<bst_float>* io_preds) override {
this->Transform(io_preds, true);
}
const char* DefaultEvalMetric() const override {
return "merror";
}
inline void Transform(HostDeviceVector<bst_float> *io_preds, bool prob) {
const int nclass = param_.num_class;
const auto ndata = static_cast<int64_t>(io_preds->Size() / nclass);
max_preds_.Resize(ndata);
io_preds->Reshard(GPUSet::Empty()); // clear out device memory
if (prob) {
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
common::Span<bst_float> point =
_preds.subspan(_idx * nclass, nclass);
common::Softmax(point.begin(), point.end());
},
common::Range{0, ndata}, GPUDistribution::Granular(devices_, nclass))
.Eval(io_preds);
} else {
io_preds->Reshard(GPUDistribution::Granular(devices_, nclass));
max_preds_.Reshard(GPUDistribution::Block(devices_));
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx,
common::Span<const bst_float> _preds,
common::Span<bst_float> _max_preds) {
common::Span<const bst_float> point =
_preds.subspan(_idx * nclass, nclass);
_max_preds[_idx] =
common::FindMaxIndex(point.cbegin(),
point.cend()) - point.cbegin();
},
common::Range{0, ndata}, devices_, false)
.Eval(io_preds, &max_preds_);
}
if (!prob) {
io_preds->Resize(max_preds_.Size());
io_preds->Copy(max_preds_);
}
io_preds->Reshard(GPUSet::Empty()); // clear out device memory
io_preds->Reshard(GPUDistribution::Block(devices_));
}
private:
// output probability
bool output_prob_;
// parameter
SoftmaxMultiClassParam param_;
GPUSet devices_;
// Cache for max_preds
HostDeviceVector<bst_float> max_preds_;
HostDeviceVector<int> label_correct_;
};
// register the objective functions
DMLC_REGISTER_PARAMETER(SoftmaxMultiClassParam);
XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax")
.describe("Softmax for multi-class classification, output class index.")
.set_body([]() { return new SoftmaxMultiClassObj(false); });
XGBOOST_REGISTER_OBJECTIVE(SoftprobMultiClass, "multi:softprob")
.describe("Softmax for multi-class classification, output probability distribution.")
.set_body([]() { return new SoftmaxMultiClassObj(true); });
} // namespace obj
} // namespace xgboost

View File

@ -30,12 +30,15 @@ ObjFunction* ObjFunction::Create(const std::string& name) {
namespace xgboost {
namespace obj {
// List of files that will be force linked in static links.
DMLC_REGISTRY_LINK_TAG(regression_obj);
#ifdef XGBOOST_USE_CUDA
DMLC_REGISTRY_LINK_TAG(regression_obj_gpu);
#endif
DMLC_REGISTRY_LINK_TAG(regression_obj_gpu);
DMLC_REGISTRY_LINK_TAG(hinge_obj_gpu);
DMLC_REGISTRY_LINK_TAG(multiclass_obj_gpu);
#else
DMLC_REGISTRY_LINK_TAG(regression_obj);
DMLC_REGISTRY_LINK_TAG(hinge_obj);
DMLC_REGISTRY_LINK_TAG(multiclass_obj);
#endif
DMLC_REGISTRY_LINK_TAG(rank_obj);
DMLC_REGISTRY_LINK_TAG(hinge);
} // namespace obj
} // namespace xgboost

View File

@ -1,426 +1,18 @@
/*!
* Copyright 2015 by Contributors
* \file regression_obj.cc
* \brief Definition of single-value regression and classification objectives.
* \author Tianqi Chen, Kailong Chen
* Copyright 2018 XGBoost contributors
*/
#include <dmlc/omp.h>
#include <xgboost/logging.h>
#include <xgboost/objective.h>
#include <vector>
#include <algorithm>
#include <utility>
#include "../common/math.h"
#include "../common/avx_helpers.h"
#include "./regression_loss.h"
// Dummy file to keep the CUDA conditional compile trick.
#include <dmlc/registry.h>
namespace xgboost {
namespace obj {
DMLC_REGISTRY_FILE_TAG(regression_obj);
struct RegLossParam : public dmlc::Parameter<RegLossParam> {
float scale_pos_weight;
// declare parameters
DMLC_DECLARE_PARAMETER(RegLossParam) {
DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f)
.describe("Scale the weight of positive examples by this factor");
}
};
// regression loss function
template <typename Loss>
class RegLossObj : public ObjFunction {
public:
RegLossObj() = default;
void Configure(
const std::vector<std::pair<std::string, std::string> > &args) override {
param_.InitAllowUnknown(args);
}
void GetGradient(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
int iter, HostDeviceVector<GradientPair> *out_gpair) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels_.Size())
<< "labels are not correctly provided"
<< "preds.size=" << preds.Size()
<< ", label.size=" << info.labels_.Size();
const auto& preds_h = preds.HostVector();
const auto& labels = info.labels_.HostVector();
const auto& weights = info.weights_.HostVector();
this->LazyCheckLabels(labels);
out_gpair->Resize(preds_h.size());
auto& gpair = out_gpair->HostVector();
const auto n = static_cast<omp_ulong>(preds_h.size());
auto gpair_ptr = out_gpair->HostPointer();
avx::Float8 scale(param_.scale_pos_weight);
const omp_ulong remainder = n % 8;
#pragma omp parallel for schedule(static)
for (omp_ulong i = 0; i < n - remainder; i += 8) {
avx::Float8 y(&labels[i]);
avx::Float8 p = Loss::PredTransform(avx::Float8(&preds_h[i]));
avx::Float8 w = weights.empty() ? avx::Float8(1.0f)
: avx::Float8(&weights[i]);
// Adjust weight
w += y * (scale * w - w);
avx::Float8 grad = Loss::FirstOrderGradient(p, y);
avx::Float8 hess = Loss::SecondOrderGradient(p, y);
avx::StoreGpair(gpair_ptr + i, grad * w, hess * w);
}
for (omp_ulong i = n - remainder; i < n; ++i) {
auto y = labels[i];
bst_float p = Loss::PredTransform(preds_h[i]);
bst_float w = info.GetWeight(i);
w += y * ((param_.scale_pos_weight * w) - w);
gpair[i] = GradientPair(Loss::FirstOrderGradient(p, y) * w,
Loss::SecondOrderGradient(p, y) * w);
}
}
const char *DefaultEvalMetric() const override {
return Loss::DefaultEvalMetric();
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
std::vector<bst_float> &preds = io_preds->HostVector();
const auto ndata = static_cast<bst_omp_uint>(preds.size());
#pragma omp parallel for schedule(static)
for (bst_omp_uint j = 0; j < ndata; ++j) {
preds[j] = Loss::PredTransform(preds[j]);
}
}
bst_float ProbToMargin(bst_float base_score) const override {
return Loss::ProbToMargin(base_score);
}
protected:
void LazyCheckLabels(const std::vector<float> &labels) {
if (labels_checked_) return;
for (auto &y : labels) {
CHECK(Loss::CheckLabel(y)) << Loss::LabelErrorMsg();
}
labels_checked_ = true;
}
RegLossParam param_;
bool labels_checked_{false};
};
// register the objective functions
DMLC_REGISTER_PARAMETER(RegLossParam);
XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear")
.describe("Linear regression.")
.set_body([]() { return new RegLossObj<LinearSquareLoss>(); });
XGBOOST_REGISTER_OBJECTIVE(LogisticRegression, "reg:logistic")
.describe("Logistic regression for probability regression task.")
.set_body([]() { return new RegLossObj<LogisticRegression>(); });
XGBOOST_REGISTER_OBJECTIVE(LogisticClassification, "binary:logistic")
.describe("Logistic regression for binary classification task.")
.set_body([]() { return new RegLossObj<LogisticClassification>(); });
XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, "binary:logitraw")
.describe("Logistic regression for classification, output score before logistic transformation")
.set_body([]() { return new RegLossObj<LogisticRaw>(); });
// declare parameter
struct PoissonRegressionParam : public dmlc::Parameter<PoissonRegressionParam> {
float max_delta_step;
DMLC_DECLARE_PARAMETER(PoissonRegressionParam) {
DMLC_DECLARE_FIELD(max_delta_step).set_lower_bound(0.0f).set_default(0.7f)
.describe("Maximum delta step we allow each weight estimation to be." \
" This parameter is required for possion regression.");
}
};
// poisson regression for count
class PoissonRegression : public ObjFunction {
public:
// declare functions
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.InitAllowUnknown(args);
}
void GetGradient(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
int iter,
HostDeviceVector<GradientPair> *out_gpair) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
const auto& preds_h = preds.HostVector();
out_gpair->Resize(preds.Size());
auto& gpair = out_gpair->HostVector();
const auto& labels = info.labels_.HostVector();
// check if label in range
bool label_correct = true;
// start calculating gradient
const omp_ulong ndata = static_cast<omp_ulong>(preds_h.size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
bst_float p = preds_h[i];
bst_float w = info.GetWeight(i);
bst_float y = labels[i];
if (y >= 0.0f) {
gpair[i] = GradientPair((std::exp(p) - y) * w,
std::exp(p + param_.max_delta_step) * w);
} else {
label_correct = false;
}
}
CHECK(label_correct) << "PoissonRegression: label must be nonnegative";
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
std::vector<bst_float> &preds = io_preds->HostVector();
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
preds[j] = std::exp(preds[j]);
}
}
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
PredTransform(io_preds);
}
bst_float ProbToMargin(bst_float base_score) const override {
return std::log(base_score);
}
const char* DefaultEvalMetric() const override {
return "poisson-nloglik";
}
private:
PoissonRegressionParam param_;
};
// register the objective functions
DMLC_REGISTER_PARAMETER(PoissonRegressionParam);
XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson")
.describe("Possion regression for count data.")
.set_body([]() { return new PoissonRegression(); });
// cox regression for survival data (negative values mean they are censored)
class CoxRegression : public ObjFunction {
public:
// declare functions
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {}
void GetGradient(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
int iter,
HostDeviceVector<GradientPair> *out_gpair) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
const auto& preds_h = preds.HostVector();
out_gpair->Resize(preds_h.size());
auto& gpair = out_gpair->HostVector();
const std::vector<size_t> &label_order = info.LabelAbsSort();
const omp_ulong ndata = static_cast<omp_ulong>(preds_h.size()); // NOLINT(*)
// pre-compute a sum
double exp_p_sum = 0; // we use double because we might need the precision with large datasets
for (omp_ulong i = 0; i < ndata; ++i) {
exp_p_sum += std::exp(preds_h[label_order[i]]);
}
// start calculating grad and hess
const auto& labels = info.labels_.HostVector();
double r_k = 0;
double s_k = 0;
double last_exp_p = 0.0;
double last_abs_y = 0.0;
double accumulated_sum = 0;
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
const size_t ind = label_order[i];
const double p = preds_h[ind];
const double exp_p = std::exp(p);
const double w = info.GetWeight(ind);
const double y = labels[ind];
const double abs_y = std::abs(y);
// only update the denominator after we move forward in time (labels are sorted)
// this is Breslow's method for ties
accumulated_sum += last_exp_p;
if (last_abs_y < abs_y) {
exp_p_sum -= accumulated_sum;
accumulated_sum = 0;
} else {
CHECK(last_abs_y <= abs_y) << "CoxRegression: labels must be in sorted order, " <<
"MetaInfo::LabelArgsort failed!";
}
if (y > 0) {
r_k += 1.0/exp_p_sum;
s_k += 1.0/(exp_p_sum*exp_p_sum);
}
const double grad = exp_p*r_k - static_cast<bst_float>(y > 0);
const double hess = exp_p*r_k - exp_p*exp_p * s_k;
gpair.at(ind) = GradientPair(grad * w, hess * w);
last_abs_y = abs_y;
last_exp_p = exp_p;
}
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
std::vector<bst_float> &preds = io_preds->HostVector();
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
preds[j] = std::exp(preds[j]);
}
}
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
PredTransform(io_preds);
}
bst_float ProbToMargin(bst_float base_score) const override {
return std::log(base_score);
}
const char* DefaultEvalMetric() const override {
return "cox-nloglik";
}
};
// register the objective function
XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox")
.describe("Cox regression for censored survival data (negative labels are considered censored).")
.set_body([]() { return new CoxRegression(); });
// gamma regression
class GammaRegression : public ObjFunction {
public:
// declare functions
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
}
void GetGradient(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
int iter,
HostDeviceVector<GradientPair> *out_gpair) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
const auto& preds_h = preds.HostVector();
out_gpair->Resize(preds_h.size());
auto& gpair = out_gpair->HostVector();
const auto& labels = info.labels_.HostVector();
// check if label in range
bool label_correct = true;
// start calculating gradient
const omp_ulong ndata = static_cast<omp_ulong>(preds_h.size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
bst_float p = preds_h[i];
bst_float w = info.GetWeight(i);
bst_float y = labels[i];
if (y >= 0.0f) {
gpair[i] = GradientPair((1 - y / std::exp(p)) * w, y / std::exp(p) * w);
} else {
label_correct = false;
}
}
CHECK(label_correct) << "GammaRegression: label must be positive";
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
std::vector<bst_float> &preds = io_preds->HostVector();
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
preds[j] = std::exp(preds[j]);
}
}
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
PredTransform(io_preds);
}
bst_float ProbToMargin(bst_float base_score) const override {
return std::log(base_score);
}
const char* DefaultEvalMetric() const override {
return "gamma-nloglik";
}
};
// register the objective functions
XGBOOST_REGISTER_OBJECTIVE(GammaRegression, "reg:gamma")
.describe("Gamma regression for severity data.")
.set_body([]() { return new GammaRegression(); });
// declare parameter
struct TweedieRegressionParam : public dmlc::Parameter<TweedieRegressionParam> {
float tweedie_variance_power;
DMLC_DECLARE_PARAMETER(TweedieRegressionParam) {
DMLC_DECLARE_FIELD(tweedie_variance_power).set_range(1.0f, 2.0f).set_default(1.5f)
.describe("Tweedie variance power. Must be between in range [1, 2).");
}
};
// tweedie regression
class TweedieRegression : public ObjFunction {
public:
// declare functions
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.InitAllowUnknown(args);
}
void GetGradient(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
int iter,
HostDeviceVector<GradientPair> *out_gpair) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
const auto& preds_h = preds.HostVector();
out_gpair->Resize(preds.Size());
auto& gpair = out_gpair->HostVector();
const auto& labels = info.labels_.HostVector();
// check if label in range
bool label_correct = true;
// start calculating gradient
const omp_ulong ndata = static_cast<omp_ulong>(preds.Size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
bst_float p = preds_h[i];
bst_float w = info.GetWeight(i);
bst_float y = labels[i];
float rho = param_.tweedie_variance_power;
if (y >= 0.0f) {
bst_float grad = -y * std::exp((1 - rho) * p) + std::exp((2 - rho) * p);
bst_float hess = -y * (1 - rho) * \
std::exp((1 - rho) * p) + (2 - rho) * std::exp((2 - rho) * p);
gpair[i] = GradientPair(grad * w, hess * w);
} else {
label_correct = false;
}
}
CHECK(label_correct) << "TweedieRegression: label must be nonnegative";
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
std::vector<bst_float> &preds = io_preds->HostVector();
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
preds[j] = std::exp(preds[j]);
}
}
bst_float ProbToMargin(bst_float base_score) const override {
return std::log(base_score);
}
const char* DefaultEvalMetric() const override {
std::ostringstream os;
os << "tweedie-nloglik@" << param_.tweedie_variance_power;
std::string metric = os.str();
return metric.c_str();
}
private:
TweedieRegressionParam param_;
};
// register the objective functions
DMLC_REGISTER_PARAMETER(TweedieRegressionParam);
XGBOOST_REGISTER_OBJECTIVE(TweedieRegression, "reg:tweedie")
.describe("Tweedie regression for insurance data.")
.set_body([]() { return new TweedieRegression(); });
} // namespace obj
} // namespace xgboost
#ifndef XGBOOST_USE_CUDA
#include "regression_obj.cu"
#endif

View File

@ -0,0 +1,560 @@
/*!
* Copyright 2015-2018 by Contributors
* \file regression_obj.cu
* \brief Definition of single-value regression and classification objectives.
* \author Tianqi Chen, Kailong Chen
*/
#include <dmlc/omp.h>
#include <xgboost/logging.h>
#include <xgboost/objective.h>
#include <cmath>
#include <memory>
#include <vector>
#include "../common/span.h"
#include "../common/transform.h"
#include "../common/common.h"
#include "../common/host_device_vector.h"
#include "./regression_loss.h"
namespace xgboost {
namespace obj {
#if defined(XGBOOST_USE_CUDA)
DMLC_REGISTRY_FILE_TAG(regression_obj_gpu);
#endif
struct RegLossParam : public dmlc::Parameter<RegLossParam> {
float scale_pos_weight;
int n_gpus;
int gpu_id;
// declare parameters
DMLC_DECLARE_PARAMETER(RegLossParam) {
DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f)
.describe("Scale the weight of positive examples by this factor");
DMLC_DECLARE_FIELD(n_gpus).set_default(-1).set_lower_bound(-1)
.describe("Number of GPUs to use for multi-gpu algorithms.");
DMLC_DECLARE_FIELD(gpu_id)
.set_lower_bound(0)
.set_default(0)
.describe("gpu to use for objective function evaluation");
}
};
template<typename Loss>
class RegLossObj : public ObjFunction {
protected:
HostDeviceVector<int> label_correct_;
public:
RegLossObj() = default;
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.InitAllowUnknown(args);
CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1
devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size());
}
void GetGradient(const HostDeviceVector<bst_float>& preds,
const MetaInfo &info,
int iter,
HostDeviceVector<GradientPair>* out_gpair) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels_.Size())
<< "labels are not correctly provided"
<< "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size();
size_t ndata = preds.Size();
out_gpair->Resize(ndata);
label_correct_.Fill(1);
bool is_null_weight = info.weights_.Size() == 0;
auto scale_pos_weight = param_.scale_pos_weight;
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx,
common::Span<int> _label_correct,
common::Span<GradientPair> _out_gpair,
common::Span<const bst_float> _preds,
common::Span<const bst_float> _labels,
common::Span<const bst_float> _weights) {
bst_float p = Loss::PredTransform(_preds[_idx]);
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
bst_float label = _labels[_idx];
if (label == 1.0f) {
w *= scale_pos_weight;
}
if (!Loss::CheckLabel(label)) {
// If there is an incorrect label, the host code will know.
_label_correct[0] = 0;
}
_out_gpair[_idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w,
Loss::SecondOrderGradient(p, label) * w);
},
common::Range{0, static_cast<int64_t>(ndata)}, devices_).Eval(
&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
// copy "label correct" flags back to host
std::vector<int>& label_correct_h = label_correct_.HostVector();
for (auto const flag : label_correct_h) {
if (flag == 0) {
LOG(FATAL) << Loss::LabelErrorMsg();
}
}
}
public:
const char* DefaultEvalMetric() const override {
return Loss::DefaultEvalMetric();
}
void PredTransform(HostDeviceVector<float> *io_preds) override {
common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<float> _preds) {
_preds[_idx] = Loss::PredTransform(_preds[_idx]);
}, common::Range{0, static_cast<int64_t>(io_preds->Size())},
devices_).Eval(io_preds);
}
float ProbToMargin(float base_score) const override {
return Loss::ProbToMargin(base_score);
}
protected:
RegLossParam param_;
GPUSet devices_;
};
// register the objective functions
DMLC_REGISTER_PARAMETER(RegLossParam);
XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear")
.describe("Linear regression.")
.set_body([]() { return new RegLossObj<LinearSquareLoss>(); });
XGBOOST_REGISTER_OBJECTIVE(LogisticRegression, "reg:logistic")
.describe("Logistic regression for probability regression task.")
.set_body([]() { return new RegLossObj<LogisticRegression>(); });
XGBOOST_REGISTER_OBJECTIVE(LogisticClassification, "binary:logistic")
.describe("Logistic regression for binary classification task.")
.set_body([]() { return new RegLossObj<LogisticClassification>(); });
XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, "binary:logitraw")
.describe("Logistic regression for classification, output score "
"before logistic transformation.")
.set_body([]() { return new RegLossObj<LogisticRaw>(); });
// Deprecated GPU functions
XGBOOST_REGISTER_OBJECTIVE(GPULinearRegression, "gpu:reg:linear")
.describe("Deprecated. Linear regression (computed on GPU).")
.set_body([]() {
LOG(WARNING) << "gpu:reg:linear is now deprecated, use reg:linear instead.";
return new RegLossObj<LinearSquareLoss>(); });
XGBOOST_REGISTER_OBJECTIVE(GPULogisticRegression, "gpu:reg:logistic")
.describe("Deprecated. Logistic regression for probability regression task (computed on GPU).")
.set_body([]() {
LOG(WARNING) << "gpu:reg:logistic is now deprecated, use reg:logistic instead.";
return new RegLossObj<LogisticRegression>(); });
XGBOOST_REGISTER_OBJECTIVE(GPULogisticClassification, "gpu:binary:logistic")
.describe("Deprecated. Logistic regression for binary classification task (computed on GPU).")
.set_body([]() {
LOG(WARNING) << "gpu:binary:logistic is now deprecated, use binary:logistic instead.";
return new RegLossObj<LogisticClassification>(); });
XGBOOST_REGISTER_OBJECTIVE(GPULogisticRaw, "gpu:binary:logitraw")
.describe("Deprecated. Logistic regression for classification, output score "
"before logistic transformation (computed on GPU)")
.set_body([]() {
LOG(WARNING) << "gpu:binary:logitraw is now deprecated, use binary:logitraw instead.";
return new RegLossObj<LogisticRaw>(); });
// End deprecated
// declare parameter
struct PoissonRegressionParam : public dmlc::Parameter<PoissonRegressionParam> {
float max_delta_step;
int n_gpus;
int gpu_id;
DMLC_DECLARE_PARAMETER(PoissonRegressionParam) {
DMLC_DECLARE_FIELD(max_delta_step).set_lower_bound(0.0f).set_default(0.7f)
.describe("Maximum delta step we allow each weight estimation to be." \
" This parameter is required for possion regression.");
DMLC_DECLARE_FIELD(n_gpus).set_default(-1).set_lower_bound(-1)
.describe("Number of GPUs to use for multi-gpu algorithms.");
DMLC_DECLARE_FIELD(gpu_id)
.set_lower_bound(0)
.set_default(0)
.describe("gpu to use for objective function evaluation");
}
};
// poisson regression for count
class PoissonRegression : public ObjFunction {
public:
// declare functions
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.InitAllowUnknown(args);
CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1
devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size());
}
void GetGradient(const HostDeviceVector<bst_float>& preds,
const MetaInfo &info,
int iter,
HostDeviceVector<GradientPair> *out_gpair) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
size_t ndata = preds.Size();
out_gpair->Resize(ndata);
label_correct_.Fill(1);
bool is_null_weight = info.weights_.Size() == 0;
bst_float max_delta_step = param_.max_delta_step;
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx,
common::Span<int> _label_correct,
common::Span<GradientPair> _out_gpair,
common::Span<const bst_float> _preds,
common::Span<const bst_float> _labels,
common::Span<const bst_float> _weights) {
bst_float p = _preds[_idx];
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
bst_float y = _labels[_idx];
if (y < 0.0f) {
_label_correct[0] = 0;
}
_out_gpair[_idx] = GradientPair{(expf(p) - y) * w,
expf(p + max_delta_step) * w};
},
common::Range{0, static_cast<int64_t>(ndata)}, devices_).Eval(
&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
// copy "label correct" flags back to host
std::vector<int>& label_correct_h = label_correct_.HostVector();
for (auto const flag : label_correct_h) {
if (flag == 0) {
LOG(FATAL) << "PoissonRegression: label must be nonnegative";
}
}
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
_preds[_idx] = expf(_preds[_idx]);
},
common::Range{0, static_cast<int64_t>(io_preds->Size())}, devices_)
.Eval(io_preds);
}
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
PredTransform(io_preds);
}
bst_float ProbToMargin(bst_float base_score) const override {
return std::log(base_score);
}
const char* DefaultEvalMetric() const override {
return "poisson-nloglik";
}
private:
GPUSet devices_;
PoissonRegressionParam param_;
HostDeviceVector<int> label_correct_;
};
// register the objective functions
DMLC_REGISTER_PARAMETER(PoissonRegressionParam);
XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson")
.describe("Possion regression for count data.")
.set_body([]() { return new PoissonRegression(); });
// cox regression for survival data (negative values mean they are censored)
class CoxRegression : public ObjFunction {
public:
// declare functions
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {}
void GetGradient(const HostDeviceVector<bst_float>& preds,
const MetaInfo &info,
int iter,
HostDeviceVector<GradientPair> *out_gpair) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
const auto& preds_h = preds.HostVector();
out_gpair->Resize(preds_h.size());
auto& gpair = out_gpair->HostVector();
const std::vector<size_t> &label_order = info.LabelAbsSort();
const omp_ulong ndata = static_cast<omp_ulong>(preds_h.size()); // NOLINT(*)
// pre-compute a sum
double exp_p_sum = 0; // we use double because we might need the precision with large datasets
for (omp_ulong i = 0; i < ndata; ++i) {
exp_p_sum += std::exp(preds_h[label_order[i]]);
}
// start calculating grad and hess
const auto& labels = info.labels_.HostVector();
double r_k = 0;
double s_k = 0;
double last_exp_p = 0.0;
double last_abs_y = 0.0;
double accumulated_sum = 0;
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
const size_t ind = label_order[i];
const double p = preds_h[ind];
const double exp_p = std::exp(p);
const double w = info.GetWeight(ind);
const double y = labels[ind];
const double abs_y = std::abs(y);
// only update the denominator after we move forward in time (labels are sorted)
// this is Breslow's method for ties
accumulated_sum += last_exp_p;
if (last_abs_y < abs_y) {
exp_p_sum -= accumulated_sum;
accumulated_sum = 0;
} else {
CHECK(last_abs_y <= abs_y) << "CoxRegression: labels must be in sorted order, " <<
"MetaInfo::LabelArgsort failed!";
}
if (y > 0) {
r_k += 1.0/exp_p_sum;
s_k += 1.0/(exp_p_sum*exp_p_sum);
}
const double grad = exp_p*r_k - static_cast<bst_float>(y > 0);
const double hess = exp_p*r_k - exp_p*exp_p * s_k;
gpair.at(ind) = GradientPair(grad * w, hess * w);
last_abs_y = abs_y;
last_exp_p = exp_p;
}
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
std::vector<bst_float> &preds = io_preds->HostVector();
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
#pragma omp parallel for schedule(static)
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
preds[j] = std::exp(preds[j]);
}
}
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
PredTransform(io_preds);
}
bst_float ProbToMargin(bst_float base_score) const override {
return std::log(base_score);
}
const char* DefaultEvalMetric() const override {
return "cox-nloglik";
}
};
// register the objective function
XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox")
.describe("Cox regression for censored survival data (negative labels are considered censored).")
.set_body([]() { return new CoxRegression(); });
struct GammaRegressionParam : public dmlc::Parameter<GammaRegressionParam> {
int n_gpus;
int gpu_id;
DMLC_DECLARE_PARAMETER(GammaRegressionParam) {
DMLC_DECLARE_FIELD(n_gpus).set_default(-1).set_lower_bound(-1)
.describe("Number of GPUs to use for multi-gpu algorithms.");
DMLC_DECLARE_FIELD(gpu_id)
.set_lower_bound(0)
.set_default(0)
.describe("gpu to use for objective function evaluation");
}
};
// gamma regression
class GammaRegression : public ObjFunction {
public:
// declare functions
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.InitAllowUnknown(args);
CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1
devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size());
}
void GetGradient(const HostDeviceVector<bst_float> &preds,
const MetaInfo &info,
int iter,
HostDeviceVector<GradientPair> *out_gpair) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
const size_t ndata = preds.Size();
out_gpair->Resize(ndata);
label_correct_.Fill(1);
const bool is_null_weight = info.weights_.Size() == 0;
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx,
common::Span<int> _label_correct,
common::Span<GradientPair> _out_gpair,
common::Span<const bst_float> _preds,
common::Span<const bst_float> _labels,
common::Span<const bst_float> _weights) {
bst_float p = _preds[_idx];
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
bst_float y = _labels[_idx];
if (y < 0.0f) {
_label_correct[0] = 0;
}
_out_gpair[_idx] = GradientPair((1 - y / expf(p)) * w, y / expf(p) * w);
},
common::Range{0, static_cast<int64_t>(ndata)}, devices_).Eval(
&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
// copy "label correct" flags back to host
std::vector<int>& label_correct_h = label_correct_.HostVector();
for (auto const flag : label_correct_h) {
if (flag == 0) {
LOG(FATAL) << "GammaRegression: label must be nonnegative";
}
}
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
_preds[_idx] = expf(_preds[_idx]);
},
common::Range{0, static_cast<int64_t>(io_preds->Size())}, devices_)
.Eval(io_preds);
}
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
PredTransform(io_preds);
}
bst_float ProbToMargin(bst_float base_score) const override {
return std::log(base_score);
}
const char* DefaultEvalMetric() const override {
return "gamma-nloglik";
}
private:
GPUSet devices_;
GammaRegressionParam param_;
HostDeviceVector<int> label_correct_;
};
// register the objective functions
DMLC_REGISTER_PARAMETER(GammaRegressionParam);
// register the objective functions
XGBOOST_REGISTER_OBJECTIVE(GammaRegression, "reg:gamma")
.describe("Gamma regression for severity data.")
.set_body([]() { return new GammaRegression(); });
// declare parameter
struct TweedieRegressionParam : public dmlc::Parameter<TweedieRegressionParam> {
float tweedie_variance_power;
int n_gpus;
int gpu_id;
DMLC_DECLARE_PARAMETER(TweedieRegressionParam) {
DMLC_DECLARE_FIELD(tweedie_variance_power).set_range(1.0f, 2.0f).set_default(1.5f)
.describe("Tweedie variance power. Must be between in range [1, 2).");
DMLC_DECLARE_FIELD(n_gpus).set_default(-1).set_lower_bound(-1)
.describe("Number of GPUs to use for multi-gpu algorithms.");
DMLC_DECLARE_FIELD(gpu_id)
.set_lower_bound(0)
.set_default(0)
.describe("gpu to use for objective function evaluation");
}
};
// tweedie regression
class TweedieRegression : public ObjFunction {
public:
// declare functions
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.InitAllowUnknown(args);
CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1
devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size());
}
void GetGradient(const HostDeviceVector<bst_float>& preds,
const MetaInfo &info,
int iter,
HostDeviceVector<GradientPair> *out_gpair) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
const size_t ndata = preds.Size();
out_gpair->Resize(ndata);
label_correct_.Fill(1);
const bool is_null_weight = info.weights_.Size() == 0;
const float rho = param_.tweedie_variance_power;
common::Transform<>::Init(
[=] XGBOOST_DEVICE(size_t _idx,
common::Span<int> _label_correct,
common::Span<GradientPair> _out_gpair,
common::Span<const bst_float> _preds,
common::Span<const bst_float> _labels,
common::Span<const bst_float> _weights) {
bst_float p = _preds[_idx];
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
bst_float y = _labels[_idx];
if (y < 0.0f) {
_label_correct[0] = 0;
}
bst_float grad = -y * expf((1 - rho) * p) + expf((2 - rho) * p);
bst_float hess =
-y * (1 - rho) * \
std::exp((1 - rho) * p) + (2 - rho) * expf((2 - rho) * p);
_out_gpair[_idx] = GradientPair(grad * w, hess * w);
},
common::Range{0, static_cast<int64_t>(ndata), 1}, devices_)
.Eval(&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
// copy "label correct" flags back to host
std::vector<int>& label_correct_h = label_correct_.HostVector();
for (auto const flag : label_correct_h) {
if (flag == 0) {
LOG(FATAL) << "TweedieRegression: label must be nonnegative";
}
}
}
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
common::Transform<>::Init(
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
_preds[_idx] = expf(_preds[_idx]);
},
common::Range{0, static_cast<int64_t>(io_preds->Size())}, devices_)
.Eval(io_preds);
}
bst_float ProbToMargin(bst_float base_score) const override {
return std::log(base_score);
}
const char* DefaultEvalMetric() const override {
std::ostringstream os;
os << "tweedie-nloglik@" << param_.tweedie_variance_power;
std::string metric = os.str();
return metric.c_str();
}
private:
GPUSet devices_;
TweedieRegressionParam param_;
HostDeviceVector<int> label_correct_;
};
// register the objective functions
DMLC_REGISTER_PARAMETER(TweedieRegressionParam);
XGBOOST_REGISTER_OBJECTIVE(TweedieRegression, "reg:tweedie")
.describe("Tweedie regression for insurance data.")
.set_body([]() { return new TweedieRegression(); });
} // namespace obj
} // namespace xgboost

View File

@ -1,202 +0,0 @@
/*!
* Copyright 2017 XGBoost contributors
*/
// GPU implementation of objective function.
// Necessary to avoid extra copying of data to CPU.
#include <dmlc/omp.h>
#include <thrust/device_vector.h>
#include <thrust/copy.h>
#include <xgboost/logging.h>
#include <xgboost/objective.h>
#include <cmath>
#include <memory>
#include <vector>
#include "../common/span.h"
#include "../common/device_helpers.cuh"
#include "../common/host_device_vector.h"
#include "./regression_loss.h"
namespace xgboost {
namespace obj {
using dh::DVec;
DMLC_REGISTRY_FILE_TAG(regression_obj_gpu);
struct GPURegLossParam : public dmlc::Parameter<GPURegLossParam> {
float scale_pos_weight;
int n_gpus;
int gpu_id;
// declare parameters
DMLC_DECLARE_PARAMETER(GPURegLossParam) {
DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f)
.describe("Scale the weight of positive examples by this factor");
DMLC_DECLARE_FIELD(n_gpus).set_default(1).set_lower_bound(-1)
.describe("Number of GPUs to use for multi-gpu algorithms (NOT IMPLEMENTED)");
DMLC_DECLARE_FIELD(gpu_id)
.set_lower_bound(0)
.set_default(0)
.describe("gpu to use for objective function evaluation");
}
};
// GPU kernel for gradient computation
template<typename Loss>
__global__ void get_gradient_k
(common::Span<GradientPair> out_gpair, common::Span<int> label_correct,
common::Span<const float> preds, common::Span<const float> labels,
const float * __restrict__ weights, int n, float scale_pos_weight) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= n)
return;
float p = Loss::PredTransform(preds[i]);
float w = weights == nullptr ? 1.0f : weights[i];
float label = labels[i];
if (label == 1.0f)
w *= scale_pos_weight;
if (!Loss::CheckLabel(label))
atomicAnd(label_correct.data(), 0);
out_gpair[i] = GradientPair
(Loss::FirstOrderGradient(p, label) * w, Loss::SecondOrderGradient(p, label) * w);
}
// GPU kernel for predicate transformation
template<typename Loss>
__global__ void pred_transform_k(common::Span<float> preds, int n) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= n)
return;
preds[i] = Loss::PredTransform(preds[i]);
}
// regression loss function for evaluation on GPU (eventually)
template<typename Loss>
class GPURegLossObj : public ObjFunction {
protected:
HostDeviceVector<int> label_correct_;
// allocate device data for n elements, do nothing if memory is allocated already
void LazyResize() {
}
public:
GPURegLossObj() {}
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
param_.InitAllowUnknown(args);
CHECK(param_.n_gpus != 0) << "Must have at least one device";
devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
label_correct_.Reshard(devices_);
label_correct_.Resize(devices_.Size());
}
void GetGradient(const HostDeviceVector<float> &preds,
const MetaInfo &info,
int iter,
HostDeviceVector<GradientPair>* out_gpair) override {
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
CHECK_EQ(preds.Size(), info.labels_.Size())
<< "labels are not correctly provided"
<< "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size();
size_t ndata = preds.Size();
preds.Reshard(devices_);
info.labels_.Reshard(devices_);
info.weights_.Reshard(devices_);
out_gpair->Reshard(devices_);
out_gpair->Resize(ndata);
GetGradientDevice(preds, info, iter, out_gpair);
}
private:
void GetGradientDevice(const HostDeviceVector<float>& preds,
const MetaInfo &info,
int iter,
HostDeviceVector<GradientPair>* out_gpair) {
label_correct_.Fill(1);
// run the kernel
#pragma omp parallel for schedule(static, 1) if (devices_.Size() > 1)
for (int i = 0; i < devices_.Size(); ++i) {
int d = devices_[i];
dh::safe_cuda(cudaSetDevice(d));
const int block = 256;
size_t n = preds.DeviceSize(d);
if (n > 0) {
get_gradient_k<Loss><<<dh::DivRoundUp(n, block), block>>>
(out_gpair->DeviceSpan(d), label_correct_.DeviceSpan(d),
preds.DeviceSpan(d), info.labels_.DeviceSpan(d),
info.weights_.Size() > 0 ? info.weights_.DevicePointer(d) : nullptr,
n, param_.scale_pos_weight);
dh::safe_cuda(cudaGetLastError());
}
dh::safe_cuda(cudaDeviceSynchronize());
}
// copy "label correct" flags back to host
std::vector<int>& label_correct_h = label_correct_.HostVector();
for (int i = 0; i < devices_.Size(); ++i) {
if (label_correct_h[i] == 0)
LOG(FATAL) << Loss::LabelErrorMsg();
}
}
public:
const char* DefaultEvalMetric() const override {
return Loss::DefaultEvalMetric();
}
void PredTransform(HostDeviceVector<float> *io_preds) override {
io_preds->Reshard(devices_);
size_t ndata = io_preds->Size();
PredTransformDevice(io_preds);
}
void PredTransformDevice(HostDeviceVector<float>* preds) {
#pragma omp parallel for schedule(static, 1) if (devices_.Size() > 1)
for (int i = 0; i < devices_.Size(); ++i) {
int d = devices_[i];
dh::safe_cuda(cudaSetDevice(d));
const int block = 256;
size_t n = preds->DeviceSize(d);
if (n > 0) {
pred_transform_k<Loss><<<dh::DivRoundUp(n, block), block>>>(
preds->DeviceSpan(d), n);
dh::safe_cuda(cudaGetLastError());
}
dh::safe_cuda(cudaDeviceSynchronize());
}
}
float ProbToMargin(float base_score) const override {
return Loss::ProbToMargin(base_score);
}
protected:
GPURegLossParam param_;
GPUSet devices_;
};
// register the objective functions
DMLC_REGISTER_PARAMETER(GPURegLossParam);
XGBOOST_REGISTER_OBJECTIVE(GPULinearRegression, "gpu:reg:linear")
.describe("Linear regression (computed on GPU).")
.set_body([]() { return new GPURegLossObj<LinearSquareLoss>(); });
XGBOOST_REGISTER_OBJECTIVE(GPULogisticRegression, "gpu:reg:logistic")
.describe("Logistic regression for probability regression task (computed on GPU).")
.set_body([]() { return new GPURegLossObj<LogisticRegression>(); });
XGBOOST_REGISTER_OBJECTIVE(GPULogisticClassification, "gpu:binary:logistic")
.describe("Logistic regression for binary classification task (computed on GPU).")
.set_body([]() { return new GPURegLossObj<LogisticClassification>(); });
XGBOOST_REGISTER_OBJECTIVE(GPULogisticRaw, "gpu:binary:logitraw")
.describe("Logistic regression for classification, output score "
"before logistic transformation (computed on GPU)")
.set_body([]() { return new GPURegLossObj<LogisticRaw>(); });
} // namespace obj
} // namespace xgboost

View File

@ -14,7 +14,7 @@ struct WriteSymbolFunction {
WriteSymbolFunction(CompressedBufferWriter cbw, unsigned char* buffer_data_d,
int* input_data_d)
: cbw(cbw), buffer_data_d(buffer_data_d), input_data_d(input_data_d) {}
__device__ void operator()(size_t i) {
cbw.AtomicWriteSymbol(buffer_data_d, input_data_d[i], i);
}
@ -28,7 +28,7 @@ struct ReadSymbolFunction {
__device__ void operator()(size_t i) {
output_data_d[i] = ci[i];
}
}
};
TEST(CompressedIterator, TestGPU) {

View File

@ -10,7 +10,7 @@
namespace xgboost {
namespace common {
TEST(gpu_hist_util, TestDeviceSketch) {
void TestDeviceSketch(const GPUSet& devices) {
// create the data
int nrows = 10001;
std::vector<float> test_data(nrows);
@ -28,7 +28,7 @@ TEST(gpu_hist_util, TestDeviceSketch) {
tree::TrainParam p;
p.max_bin = 20;
p.gpu_id = 0;
p.n_gpus = GPUSet::AllVisible().Size();
p.n_gpus = devices.Size();
// ensure that the exact quantiles are found
p.gpu_batch_nrows = nrows * 10;
@ -54,5 +54,17 @@ TEST(gpu_hist_util, TestDeviceSketch) {
delete dmat;
}
TEST(gpu_hist_util, DeviceSketch) {
TestDeviceSketch(GPUSet::Range(0, 1));
}
#if defined(XGBOOST_USE_NCCL)
TEST(gpu_hist_util, MGPU_DeviceSketch) {
auto devices = GPUSet::AllVisible();
CHECK_GT(devices.Size(), 1);
TestDeviceSketch(devices);
}
#endif
} // namespace common
} // namespace xgboost

View File

@ -178,18 +178,57 @@ TEST(HostDeviceVector, TestCopy) {
SetCudaSetDeviceHandler(nullptr);
}
// The test is not really useful if n_gpus < 2
TEST(HostDeviceVector, Reshard) {
std::vector<int> h_vec (2345);
for (size_t i = 0; i < h_vec.size(); ++i) {
h_vec[i] = i;
}
HostDeviceVector<int> vec (h_vec);
auto devices = GPUSet::Range(0, 1);
vec.Reshard(devices);
ASSERT_EQ(vec.DeviceSize(0), h_vec.size());
ASSERT_EQ(vec.Size(), h_vec.size());
auto span = vec.DeviceSpan(0); // sync to device
vec.Reshard(GPUSet::Empty()); // pull back to cpu, empty devices.
ASSERT_EQ(vec.Size(), h_vec.size());
ASSERT_TRUE(vec.Devices().IsEmpty());
auto h_vec_1 = vec.HostVector();
ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
}
TEST(HostDeviceVector, Span) {
HostDeviceVector<float> vec {1.0f, 2.0f, 3.0f, 4.0f};
vec.Reshard(GPUSet{0, 1});
auto span = vec.DeviceSpan(0);
ASSERT_EQ(vec.DeviceSize(0), span.size());
ASSERT_EQ(vec.DevicePointer(0), span.data());
auto const_span = vec.ConstDeviceSpan(0);
ASSERT_EQ(vec.DeviceSize(0), span.size());
ASSERT_EQ(vec.ConstDevicePointer(0), span.data());
}
// Multi-GPUs' test
#if defined(XGBOOST_USE_NCCL)
TEST(HostDeviceVector, MGPU_Reshard) {
auto devices = GPUSet::AllVisible();
if (devices.Size() < 2) {
LOG(WARNING) << "Not testing in multi-gpu environment.";
return;
}
std::vector<int> h_vec (2345);
for (size_t i = 0; i < h_vec.size(); ++i) {
h_vec[i] = i;
}
HostDeviceVector<int> vec (h_vec);
// Data size for each device.
std::vector<size_t> devices_size (devices.Size());
// From CPU to GPUs.
// Assuming we have > 1 devices.
vec.Reshard(devices);
size_t total_size = 0;
for (size_t i = 0; i < devices.Size(); ++i) {
@ -198,42 +237,26 @@ TEST(HostDeviceVector, Reshard) {
}
ASSERT_EQ(total_size, h_vec.size());
ASSERT_EQ(total_size, vec.Size());
auto h_vec_1 = vec.HostVector();
ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
vec.Reshard(GPUSet::Empty()); // clear out devices memory
// Reshard from devices to devices with different distribution.
EXPECT_ANY_THROW(
vec.Reshard(GPUDistribution::Granular(devices, 12)));
// Shrink down the number of devices.
vec.Reshard(GPUSet::Range(0, 1));
// All data is drawn back to CPU
vec.Reshard(GPUSet::Empty());
ASSERT_TRUE(vec.Devices().IsEmpty());
ASSERT_EQ(vec.Size(), h_vec.size());
ASSERT_EQ(vec.DeviceSize(0), h_vec.size());
h_vec_1 = vec.HostVector();
ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
vec.Reshard(GPUSet::Empty()); // clear out devices memory
// Grow the number of devices.
vec.Reshard(devices);
vec.Reshard(GPUDistribution::Granular(devices, 12));
total_size = 0;
for (size_t i = 0; i < devices.Size(); ++i) {
total_size += vec.DeviceSize(i);
ASSERT_EQ(devices_size[i], vec.DeviceSize(i));
devices_size[i] = vec.DeviceSize(i);
}
ASSERT_EQ(total_size, h_vec.size());
ASSERT_EQ(total_size, vec.Size());
h_vec_1 = vec.HostVector();
ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
}
TEST(HostDeviceVector, Span) {
HostDeviceVector<float> vec {1.0f, 2.0f, 3.0f, 4.0f};
vec.Reshard(GPUSet{0, 1});
auto span = vec.DeviceSpan(0);
ASSERT_EQ(vec.Size(), span.size());
ASSERT_EQ(vec.DevicePointer(0), span.data());
auto const_span = vec.ConstDeviceSpan(0);
ASSERT_EQ(vec.Size(), span.size());
ASSERT_EQ(vec.ConstDevicePointer(0), span.data());
}
#endif
} // namespace common
} // namespace xgboost

View File

@ -7,6 +7,14 @@
#include "../../include/xgboost/base.h"
#include "../../../src/common/span.h"
template <typename Iter>
XGBOOST_DEVICE void InitializeRange(Iter _begin, Iter _end) {
float j = 0;
for (Iter i = _begin; i != _end; ++i, ++j) {
*i = j;
}
}
namespace xgboost {
namespace common {
@ -20,14 +28,6 @@ namespace common {
*(status) = -1; \
}
template <typename Iter>
XGBOOST_DEVICE void InitializeRange(Iter _begin, Iter _end) {
float j = 0;
for (Iter i = _begin; i != _end; ++i, ++j) {
*i = j;
}
}
struct TestTestStatus {
int * status_;

View File

@ -0,0 +1,61 @@
#include <xgboost/base.h>
#include <gtest/gtest.h>
#include <vector>
#include "../../../src/common/host_device_vector.h"
#include "../../../src/common/transform.h"
#include "../../../src/common/span.h"
#include "../helpers.h"
#if defined(__CUDACC__)
#define TRANSFORM_GPU_RANGE GPUSet::Range(0, 1)
#define TRANSFORM_GPU_DIST GPUDistribution::Block(GPUSet::Range(0, 1))
#else
#define TRANSFORM_GPU_RANGE GPUSet::Empty()
#define TRANSFORM_GPU_DIST GPUDistribution::Block(GPUSet::Empty())
#endif
template <typename Iter>
void InitializeRange(Iter _begin, Iter _end) {
float j = 0;
for (Iter i = _begin; i != _end; ++i, ++j) {
*i = j;
}
}
namespace xgboost {
namespace common {
template <typename T>
struct TestTransformRange {
void XGBOOST_DEVICE operator()(size_t _idx,
Span<bst_float> _out, Span<const bst_float> _in) {
_out[_idx] = _in[_idx];
}
};
TEST(Transform, DeclareUnifiedTest(Basic)) {
const size_t size {256};
std::vector<bst_float> h_in(size);
std::vector<bst_float> h_out(size);
InitializeRange(h_in.begin(), h_in.end());
std::vector<bst_float> h_sol(size);
InitializeRange(h_sol.begin(), h_sol.end());
const HostDeviceVector<bst_float> in_vec{h_in, TRANSFORM_GPU_DIST};
HostDeviceVector<bst_float> out_vec{h_out, TRANSFORM_GPU_DIST};
out_vec.Fill(0);
Transform<>::Init(TestTransformRange<bst_float>{}, Range{0, size}, TRANSFORM_GPU_RANGE)
.Eval(&out_vec, &in_vec);
std::vector<bst_float> res = out_vec.HostVector();
ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin()));
}
} // namespace common
} // namespace xgboost

View File

@ -0,0 +1,43 @@
// This converts all tests from CPU to GPU.
#include "test_transform_range.cc"
#if defined(XGBOOST_USE_NCCL)
namespace xgboost {
namespace common {
// Test here is multi gpu specific
TEST(Transform, MGPU_Basic) {
auto devices = GPUSet::AllVisible();
CHECK_GT(devices.Size(), 1);
const size_t size {256};
std::vector<bst_float> h_in(size);
std::vector<bst_float> h_out(size);
InitializeRange(h_in.begin(), h_in.end());
std::vector<bst_float> h_sol(size);
InitializeRange(h_sol.begin(), h_sol.end());
const HostDeviceVector<bst_float> in_vec {h_in,
GPUDistribution::Block(GPUSet::Empty())};
HostDeviceVector<bst_float> out_vec {h_out,
GPUDistribution::Block(GPUSet::Empty())};
out_vec.Fill(0);
in_vec.Reshard(GPUDistribution::Granular(devices, 8));
out_vec.Reshard(GPUDistribution::Block(devices));
// Granularity is different, resharding will throw.
EXPECT_ANY_THROW(
Transform<>::Init(TestTransformRange<bst_float>{}, Range{0, size}, devices)
.Eval(&out_vec, &in_vec));
Transform<>::Init(TestTransformRange<bst_float>{}, Range{0, size},
devices, false).Eval(&out_vec, &in_vec);
std::vector<bst_float> res = out_vec.HostVector();
ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin()));
}
} // namespace xgboost
} // namespace common
#endif

View File

@ -125,3 +125,17 @@ std::shared_ptr<xgboost::DMatrix>* CreateDMatrix(int rows, int columns,
&handle);
return static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
}
namespace xgboost {
bool IsNear(std::vector<xgboost::bst_float>::const_iterator _beg1,
std::vector<xgboost::bst_float>::const_iterator _end1,
std::vector<xgboost::bst_float>::const_iterator _beg2) {
for (auto iter1 = _beg1, iter2 = _beg2; iter1 != _end1; ++iter1, ++iter2) {
if (std::abs(*iter1 - *iter2) > xgboost::kRtEps){
return false;
}
}
return true;
}
}

View File

@ -15,6 +15,12 @@
#include <xgboost/objective.h>
#include <xgboost/metric.h>
#if defined(__CUDACC__)
#define DeclareUnifiedTest(name) GPU ## name
#else
#define DeclareUnifiedTest(name) name
#endif
std::string TempFileName();
bool FileExists(const std::string name);
@ -46,6 +52,12 @@ xgboost::bst_float GetMetricEval(
std::vector<xgboost::bst_float> labels,
std::vector<xgboost::bst_float> weights = std::vector<xgboost::bst_float> ());
namespace xgboost {
bool IsNear(std::vector<xgboost::bst_float>::const_iterator _beg1,
std::vector<xgboost::bst_float>::const_iterator _end1,
std::vector<xgboost::bst_float>::const_iterator _beg2);
}
/**
* \fn std::shared_ptr<xgboost::DMatrix> CreateDMatrix(int rows, int columns, float sparsity, int seed);
*

View File

@ -4,7 +4,7 @@
#include "../helpers.h"
TEST(Objective, HingeObj) {
TEST(Objective, DeclareUnifiedTest(HingeObj)) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("binary:hinge");
std::vector<std::pair<std::string, std::string> > args;
obj->Configure(args);
@ -15,6 +15,12 @@ TEST(Objective, HingeObj) {
{ 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f},
{ 0.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 0.0f},
{ eps, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, eps });
CheckObjFunction(obj,
{-1.0f, -0.5f, 0.5f, 1.0f, -1.0f, -0.5f, 0.5f, 1.0f},
{ 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 1.0f},
{}, // Empty weight.
{ 0.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 0.0f},
{ eps, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, eps });
ASSERT_NO_THROW(obj->DefaultEvalMetric());

View File

@ -0,0 +1 @@
#include "test_hinge.cc"

View File

@ -0,0 +1,60 @@
/*!
* Copyright 2018 XGBoost contributors
*/
#include <xgboost/objective.h>
#include "../helpers.h"
TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassObjGPair)) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("multi:softmax");
std::vector<std::pair<std::string, std::string>> args {{"num_class", "3"}};
obj->Configure(args);
CheckObjFunction(obj,
{1, 0, 2, 2, 0, 1}, // preds
{1.0, 0.0}, // labels
{1.0, 1.0}, // weights
{0.24f, -0.91f, 0.66f, -0.33f, 0.09f, 0.24f}, // grad
{0.36, 0.16, 0.44, 0.45, 0.16, 0.37}); // hess
ASSERT_NO_THROW(obj->DefaultEvalMetric());
delete obj;
}
TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassBasic)) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("multi:softmax");
std::vector<std::pair<std::string, std::string>> args
{std::pair<std::string, std::string>("num_class", "3")};
obj->Configure(args);
xgboost::HostDeviceVector<xgboost::bst_float> io_preds = {2.0f, 0.0f, 1.0f,
1.0f, 0.0f, 2.0f};
std::vector<xgboost::bst_float> out_preds = {0.0f, 2.0f};
obj->PredTransform(&io_preds);
auto& preds = io_preds.HostVector();
for (int i = 0; i < static_cast<int>(io_preds.Size()); ++i) {
EXPECT_NEAR(preds[i], out_preds[i], 0.01f);
}
delete obj;
}
TEST(Objective, DeclareUnifiedTest(SoftprobMultiClassBasic)) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("multi:softprob");
std::vector<std::pair<std::string, std::string>> args
{std::pair<std::string, std::string>("num_class", "3")};
obj->Configure(args);
xgboost::HostDeviceVector<xgboost::bst_float> io_preds = {2.0f, 0.0f, 1.0f};
std::vector<xgboost::bst_float> out_preds = {0.66524096f, 0.09003057f, 0.24472847f};
obj->PredTransform(&io_preds);
auto& preds = io_preds.HostVector();
for (int i = 0; i < static_cast<int>(io_preds.Size()); ++i) {
EXPECT_NEAR(preds[i], out_preds[i], 0.01f);
}
delete obj;
}

View File

@ -0,0 +1 @@
#include "test_multiclass_obj.cc"

View File

@ -1,9 +1,11 @@
// Copyright by Contributors
/*!
* Copyright 2017-2018 XGBoost contributors
*/
#include <xgboost/objective.h>
#include "../helpers.h"
TEST(Objective, LinearRegressionGPair) {
TEST(Objective, DeclareUnifiedTest(LinearRegressionGPair)) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:linear");
std::vector<std::pair<std::string, std::string> > args;
obj->Configure(args);
@ -13,27 +15,32 @@ TEST(Objective, LinearRegressionGPair) {
{1, 1, 1, 1, 1, 1, 1, 1},
{0, 0.1f, 0.9f, 1.0f, -1.0f, -0.9f, -0.1f, 0},
{1, 1, 1, 1, 1, 1, 1, 1});
CheckObjFunction(obj,
{0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
{0, 0, 0, 0, 1, 1, 1, 1},
{}, // empty weight
{0, 0.1f, 0.9f, 1.0f, -1.0f, -0.9f, -0.1f, 0},
{1, 1, 1, 1, 1, 1, 1, 1});
ASSERT_NO_THROW(obj->DefaultEvalMetric());
delete obj;
}
TEST(Objective, LogisticRegressionGPair) {
TEST(Objective, DeclareUnifiedTest(LogisticRegressionGPair)) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:logistic");
std::vector<std::pair<std::string, std::string> > args;
obj->Configure(args);
CheckObjFunction(obj,
{ 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
{ 0, 0, 0, 0, 1, 1, 1, 1},
{ 1, 1, 1, 1, 1, 1, 1, 1},
{ 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f},
{0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f});
{ 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, // preds
{ 0, 0, 0, 0, 1, 1, 1, 1}, // labels
{ 1, 1, 1, 1, 1, 1, 1, 1}, // weights
{ 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f}, // out_grad
{0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f}); // out_hess
delete obj;
}
TEST(Objective, LogisticRegressionBasic) {
TEST(Objective, DeclareUnifiedTest(LogisticRegressionBasic)) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:logistic");
std::vector<std::pair<std::string, std::string> > args;
obj->Configure(args);
@ -61,7 +68,7 @@ TEST(Objective, LogisticRegressionBasic) {
delete obj;
}
TEST(Objective, LogisticRawGPair) {
TEST(Objective, DeclareUnifiedTest(LogisticRawGPair)) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("binary:logitraw");
std::vector<std::pair<std::string, std::string> > args;
obj->Configure(args);
@ -75,7 +82,7 @@ TEST(Objective, LogisticRawGPair) {
delete obj;
}
TEST(Objective, PoissonRegressionGPair) {
TEST(Objective, DeclareUnifiedTest(PoissonRegressionGPair)) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("count:poisson");
std::vector<std::pair<std::string, std::string> > args;
args.push_back(std::make_pair("max_delta_step", "0.1f"));
@ -86,11 +93,16 @@ TEST(Objective, PoissonRegressionGPair) {
{ 1, 1, 1, 1, 1, 1, 1, 1},
{ 1, 1.10f, 2.45f, 2.71f, 0, 0.10f, 1.45f, 1.71f},
{1.10f, 1.22f, 2.71f, 3.00f, 1.10f, 1.22f, 2.71f, 3.00f});
CheckObjFunction(obj,
{ 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
{ 0, 0, 0, 0, 1, 1, 1, 1},
{}, // Empty weight
{ 1, 1.10f, 2.45f, 2.71f, 0, 0.10f, 1.45f, 1.71f},
{1.10f, 1.22f, 2.71f, 3.00f, 1.10f, 1.22f, 2.71f, 3.00f});
delete obj;
}
TEST(Objective, PoissonRegressionBasic) {
TEST(Objective, DeclareUnifiedTest(PoissonRegressionBasic)) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("count:poisson");
std::vector<std::pair<std::string, std::string> > args;
obj->Configure(args);
@ -116,7 +128,7 @@ TEST(Objective, PoissonRegressionBasic) {
delete obj;
}
TEST(Objective, GammaRegressionGPair) {
TEST(Objective, DeclareUnifiedTest(GammaRegressionGPair)) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:gamma");
std::vector<std::pair<std::string, std::string> > args;
obj->Configure(args);
@ -126,11 +138,16 @@ TEST(Objective, GammaRegressionGPair) {
{1, 1, 1, 1, 1, 1, 1, 1},
{1, 1, 1, 1, 0, 0.09f, 0.59f, 0.63f},
{0, 0, 0, 0, 1, 0.90f, 0.40f, 0.36f});
CheckObjFunction(obj,
{0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
{0, 0, 0, 0, 1, 1, 1, 1},
{}, // Empty weight
{1, 1, 1, 1, 0, 0.09f, 0.59f, 0.63f},
{0, 0, 0, 0, 1, 0.90f, 0.40f, 0.36f});
delete obj;
}
TEST(Objective, GammaRegressionBasic) {
TEST(Objective, DeclareUnifiedTest(GammaRegressionBasic)) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:gamma");
std::vector<std::pair<std::string, std::string> > args;
obj->Configure(args);
@ -156,7 +173,7 @@ TEST(Objective, GammaRegressionBasic) {
delete obj;
}
TEST(Objective, TweedieRegressionGPair) {
TEST(Objective, DeclareUnifiedTest(TweedieRegressionGPair)) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:tweedie");
std::vector<std::pair<std::string, std::string> > args;
args.push_back(std::make_pair("tweedie_variance_power", "1.1f"));
@ -167,11 +184,17 @@ TEST(Objective, TweedieRegressionGPair) {
{ 1, 1, 1, 1, 1, 1, 1, 1},
{ 1, 1.09f, 2.24f, 2.45f, 0, 0.10f, 1.33f, 1.55f},
{0.89f, 0.98f, 2.02f, 2.21f, 1, 1.08f, 2.11f, 2.30f});
CheckObjFunction(obj,
{ 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
{ 0, 0, 0, 0, 1, 1, 1, 1},
{}, // Empty weight.
{ 1, 1.09f, 2.24f, 2.45f, 0, 0.10f, 1.33f, 1.55f},
{0.89f, 0.98f, 2.02f, 2.21f, 1, 1.08f, 2.11f, 2.30f});
delete obj;
}
TEST(Objective, TweedieRegressionBasic) {
TEST(Objective, DeclareUnifiedTest(TweedieRegressionBasic)) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:tweedie");
std::vector<std::pair<std::string, std::string> > args;
obj->Configure(args);
@ -197,6 +220,9 @@ TEST(Objective, TweedieRegressionBasic) {
delete obj;
}
// CoxRegression not implemented in GPU code, no need for testing.
#if !defined(__CUDACC__)
TEST(Objective, CoxRegressionGPair) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("survival:cox");
std::vector<std::pair<std::string, std::string> > args;
@ -210,3 +236,4 @@ TEST(Objective, CoxRegressionGPair) {
delete obj;
}
#endif

View File

@ -1,78 +1,6 @@
/*!
* Copyright 2017 XGBoost contributors
* Copyright 2018 XGBoost contributors
*/
#include <xgboost/objective.h>
// Dummy file to keep the CUDA tests.
#include "../helpers.h"
TEST(Objective, GPULinearRegressionGPair) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:reg:linear");
std::vector<std::pair<std::string, std::string> > args;
obj->Configure(args);
CheckObjFunction(obj,
{0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
{0, 0, 0, 0, 1, 1, 1, 1},
{1, 1, 1, 1, 1, 1, 1, 1},
{0, 0.1f, 0.9f, 1.0f, -1.0f, -0.9f, -0.1f, 0},
{1, 1, 1, 1, 1, 1, 1, 1});
ASSERT_NO_THROW(obj->DefaultEvalMetric());
delete obj;
}
TEST(Objective, GPULogisticRegressionGPair) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:reg:logistic");
std::vector<std::pair<std::string, std::string> > args;
obj->Configure(args);
CheckObjFunction(obj,
{ 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
{ 0, 0, 0, 0, 1, 1, 1, 1},
{ 1, 1, 1, 1, 1, 1, 1, 1},
{ 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f},
{0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f});
delete obj;
}
TEST(Objective, GPULogisticRegressionBasic) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:reg:logistic");
std::vector<std::pair<std::string, std::string> > args;
obj->Configure(args);
// test label validation
EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {10}, {1}, {0}, {0}))
<< "Expected error when label not in range [0,1f] for LogisticRegression";
// test ProbToMargin
EXPECT_NEAR(obj->ProbToMargin(0.1f), -2.197f, 0.01f);
EXPECT_NEAR(obj->ProbToMargin(0.5f), 0, 0.01f);
EXPECT_NEAR(obj->ProbToMargin(0.9f), 2.197f, 0.01f);
EXPECT_ANY_THROW(obj->ProbToMargin(10))
<< "Expected error when base_score not in range [0,1f] for LogisticRegression";
// test PredTransform
xgboost::HostDeviceVector<xgboost::bst_float> io_preds = {0, 0.1f, 0.5f, 0.9f, 1};
std::vector<xgboost::bst_float> out_preds = {0.5f, 0.524f, 0.622f, 0.710f, 0.731f};
obj->PredTransform(&io_preds);
auto& preds = io_preds.HostVector();
for (int i = 0; i < static_cast<int>(io_preds.Size()); ++i) {
EXPECT_NEAR(preds[i], out_preds[i], 0.01f);
}
delete obj;
}
TEST(Objective, GPULogisticRawGPair) {
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:binary:logitraw");
std::vector<std::pair<std::string, std::string> > args;
obj->Configure(args);
CheckObjFunction(obj,
{ 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
{ 0, 0, 0, 0, 1, 1, 1, 1},
{ 1, 1, 1, 1, 1, 1, 1, 1},
{ 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f},
{0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f});
delete obj;
}
#include "test_regression_obj.cc"