Implement transform to reduce CPU/GPU code duplication. (#3643)
* Implement Transform class. * Add tests for softmax. * Use Transform in regression, softmax and hinge objectives, except for Cox. * Mark old gpu objective functions deprecated. * static_assert for softmax. * Split up multi-gpu tests.
This commit is contained in:
parent
87aca8c244
commit
d594b11f35
@ -1,9 +1,11 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* Copyright 2015-2018 by Contributors
|
||||
* \file common.cc
|
||||
* \brief Enable all kinds of global variables in common.
|
||||
*/
|
||||
#include <dmlc/thread_local.h>
|
||||
|
||||
#include "common.h"
|
||||
#include "./random.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
@ -11,7 +11,7 @@ int AllVisibleImpl::AllVisible() {
|
||||
// When compiled with CUDA but running on CPU only device,
|
||||
// cudaGetDeviceCount will fail.
|
||||
dh::safe_cuda(cudaGetDeviceCount(&n_visgpus));
|
||||
} catch(const std::exception& e) {
|
||||
} catch(const thrust::system::system_error& err) {
|
||||
return 0;
|
||||
}
|
||||
return n_visgpus;
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* Copyright 2015-2018 by Contributors
|
||||
* \file common.h
|
||||
* \brief Common utilities
|
||||
*/
|
||||
@ -19,6 +19,13 @@
|
||||
#if defined(__CUDACC__)
|
||||
#include <thrust/system/cuda/error.h>
|
||||
#include <thrust/system_error.h>
|
||||
|
||||
#define WITH_CUDA() true
|
||||
|
||||
#else
|
||||
|
||||
#define WITH_CUDA() false
|
||||
|
||||
#endif
|
||||
|
||||
namespace dh {
|
||||
@ -29,11 +36,11 @@ namespace dh {
|
||||
#define safe_cuda(ans) ThrowOnCudaError((ans), __FILE__, __LINE__)
|
||||
|
||||
inline cudaError_t ThrowOnCudaError(cudaError_t code, const char *file,
|
||||
int line) {
|
||||
int line) {
|
||||
if (code != cudaSuccess) {
|
||||
throw thrust::system_error(code, thrust::cuda_category(),
|
||||
std::string{file} + "(" + // NOLINT
|
||||
std::to_string(line) + ")");
|
||||
LOG(FATAL) << thrust::system_error(code, thrust::cuda_category(),
|
||||
std::string{file} + ": " + // NOLINT
|
||||
std::to_string(line)).what();
|
||||
}
|
||||
return code;
|
||||
}
|
||||
@ -70,13 +77,13 @@ inline std::string ToString(const T& data) {
|
||||
*/
|
||||
class Range {
|
||||
public:
|
||||
using DifferenceType = int64_t;
|
||||
|
||||
class Iterator {
|
||||
friend class Range;
|
||||
|
||||
public:
|
||||
using DifferenceType = int64_t;
|
||||
|
||||
XGBOOST_DEVICE int64_t operator*() const { return i_; }
|
||||
XGBOOST_DEVICE DifferenceType operator*() const { return i_; }
|
||||
XGBOOST_DEVICE const Iterator &operator++() {
|
||||
i_ += step_;
|
||||
return *this;
|
||||
@ -97,8 +104,8 @@ class Range {
|
||||
XGBOOST_DEVICE void Step(DifferenceType s) { step_ = s; }
|
||||
|
||||
protected:
|
||||
XGBOOST_DEVICE explicit Iterator(int64_t start) : i_(start) {}
|
||||
XGBOOST_DEVICE explicit Iterator(int64_t start, int step) :
|
||||
XGBOOST_DEVICE explicit Iterator(DifferenceType start) : i_(start) {}
|
||||
XGBOOST_DEVICE explicit Iterator(DifferenceType start, DifferenceType step) :
|
||||
i_{start}, step_{step} {}
|
||||
|
||||
public:
|
||||
@ -109,9 +116,10 @@ class Range {
|
||||
XGBOOST_DEVICE Iterator begin() const { return begin_; } // NOLINT
|
||||
XGBOOST_DEVICE Iterator end() const { return end_; } // NOLINT
|
||||
|
||||
XGBOOST_DEVICE Range(int64_t begin, int64_t end)
|
||||
XGBOOST_DEVICE Range(DifferenceType begin, DifferenceType end)
|
||||
: begin_(begin), end_(end) {}
|
||||
XGBOOST_DEVICE Range(int64_t begin, int64_t end, Iterator::DifferenceType step)
|
||||
XGBOOST_DEVICE Range(DifferenceType begin, DifferenceType end,
|
||||
DifferenceType step)
|
||||
: begin_(begin, step), end_(end) {}
|
||||
|
||||
XGBOOST_DEVICE bool operator==(const Range& other) const {
|
||||
@ -121,9 +129,7 @@ class Range {
|
||||
return !(*this == other);
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE void Step(Iterator::DifferenceType s) { begin_.Step(s); }
|
||||
|
||||
XGBOOST_DEVICE Iterator::DifferenceType GetStep() const { return begin_.step_; }
|
||||
XGBOOST_DEVICE void Step(DifferenceType s) { begin_.Step(s); }
|
||||
|
||||
private:
|
||||
Iterator begin_;
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
#include <xgboost/logging.h>
|
||||
|
||||
#include "common.h"
|
||||
#include "span.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <chrono>
|
||||
@ -955,7 +956,7 @@ class SaveCudaContext {
|
||||
// cudaGetDevice will fail.
|
||||
try {
|
||||
safe_cuda(cudaGetDevice(&saved_device_));
|
||||
} catch (thrust::system::system_error & err) {
|
||||
} catch (const thrust::system::system_error & err) {
|
||||
saved_device_ = -1;
|
||||
}
|
||||
func();
|
||||
@ -1035,4 +1036,22 @@ ReduceT ReduceShards(std::vector<ShardT> *shards, FunctionT f) {
|
||||
};
|
||||
return std::accumulate(sums.begin(), sums.end(), ReduceT());
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename IndexT = typename xgboost::common::Span<T>::index_type>
|
||||
xgboost::common::Span<T> ToSpan(
|
||||
thrust::device_vector<T>& vec,
|
||||
IndexT offset = 0,
|
||||
IndexT size = -1) {
|
||||
size = size == -1 ? vec.size() : size;
|
||||
CHECK_LE(offset + size, vec.size());
|
||||
return {vec.data().get() + offset, static_cast<IndexT>(size)};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
xgboost::common::Span<T> ToSpan(thrust::device_vector<T>& vec,
|
||||
size_t offset, size_t size) {
|
||||
using IndexT = typename xgboost::common::Span<T>::index_type;
|
||||
return ToSpan(vec, static_cast<IndexT>(offset), static_cast<IndexT>(size));
|
||||
}
|
||||
} // namespace dh
|
||||
|
||||
@ -116,6 +116,7 @@ struct HostDeviceVectorImpl {
|
||||
int ndevices = vec_->distribution_.devices_.Size();
|
||||
start_ = vec_->distribution_.ShardStart(new_size, index_);
|
||||
proper_size_ = vec_->distribution_.ShardProperSize(new_size, index_);
|
||||
// The size on this device.
|
||||
size_t size_d = vec_->distribution_.ShardSize(new_size, index_);
|
||||
SetDevice();
|
||||
data_.resize(size_d);
|
||||
@ -230,7 +231,7 @@ struct HostDeviceVectorImpl {
|
||||
CHECK(devices.Contains(device));
|
||||
LazySyncDevice(device, GPUAccess::kWrite);
|
||||
return {shards_[devices.Index(device)].data_.data().get(),
|
||||
static_cast<typename common::Span<T>::index_type>(DeviceSize(device))};
|
||||
static_cast<typename common::Span<T>::index_type>(DeviceSize(device))};
|
||||
}
|
||||
|
||||
common::Span<const T> ConstDeviceSpan(int device) {
|
||||
@ -238,7 +239,7 @@ struct HostDeviceVectorImpl {
|
||||
CHECK(devices.Contains(device));
|
||||
LazySyncDevice(device, GPUAccess::kRead);
|
||||
return {shards_[devices.Index(device)].data_.data().get(),
|
||||
static_cast<typename common::Span<const T>::index_type>(DeviceSize(device))};
|
||||
static_cast<typename common::Span<const T>::index_type>(DeviceSize(device))};
|
||||
}
|
||||
|
||||
size_t DeviceSize(int device) {
|
||||
@ -289,7 +290,6 @@ struct HostDeviceVectorImpl {
|
||||
data_h_.size() * sizeof(T),
|
||||
cudaMemcpyHostToDevice));
|
||||
} else {
|
||||
//
|
||||
dh::ExecuteShards(&shards_, [&](DeviceShard& shard) { shard.GatherTo(begin); });
|
||||
}
|
||||
}
|
||||
@ -304,14 +304,20 @@ struct HostDeviceVectorImpl {
|
||||
|
||||
void Copy(HostDeviceVectorImpl<T>* other) {
|
||||
CHECK_EQ(Size(), other->Size());
|
||||
// Data is on host.
|
||||
if (perm_h_.CanWrite() && other->perm_h_.CanWrite()) {
|
||||
std::copy(other->data_h_.begin(), other->data_h_.end(), data_h_.begin());
|
||||
} else {
|
||||
CHECK(distribution_ == other->distribution_);
|
||||
dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) {
|
||||
shard.Copy(&other->shards_[i]);
|
||||
});
|
||||
return;
|
||||
}
|
||||
// Data is on device;
|
||||
if (distribution_ != other->distribution_) {
|
||||
distribution_ = GPUDistribution();
|
||||
Reshard(other->Distribution());
|
||||
size_d_ = other->size_d_;
|
||||
}
|
||||
dh::ExecuteIndexShards(&shards_, [&](int i, DeviceShard& shard) {
|
||||
shard.Copy(&other->shards_[i]);
|
||||
});
|
||||
}
|
||||
|
||||
void Copy(const std::vector<T>& other) {
|
||||
|
||||
@ -111,8 +111,11 @@ class GPUDistribution {
|
||||
}
|
||||
|
||||
friend bool operator==(const GPUDistribution& a, const GPUDistribution& b) {
|
||||
return a.devices_ == b.devices_ && a.granularity_ == b.granularity_ &&
|
||||
a.overlap_ == b.overlap_ && a.offsets_ == b.offsets_;
|
||||
bool const res = a.devices_ == b.devices_ &&
|
||||
a.granularity_ == b.granularity_ &&
|
||||
a.overlap_ == b.overlap_ &&
|
||||
a.offsets_ == b.offsets_;
|
||||
return res;
|
||||
}
|
||||
|
||||
friend bool operator!=(const GPUDistribution& a, const GPUDistribution& b) {
|
||||
|
||||
@ -11,6 +11,7 @@
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "avx_helpers.h"
|
||||
|
||||
namespace xgboost {
|
||||
@ -29,22 +30,31 @@ inline avx::Float8 Sigmoid(avx::Float8 x) {
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief do inplace softmax transformaton on p_rec
|
||||
* \param p_rec the input/output vector of the values.
|
||||
* \brief Do inplace softmax transformaton on start to end
|
||||
*
|
||||
* \tparam Iterator Input iterator type
|
||||
*
|
||||
* \param start Start iterator of input
|
||||
* \param end end iterator of input
|
||||
*/
|
||||
inline void Softmax(std::vector<float>* p_rec) {
|
||||
std::vector<float> &rec = *p_rec;
|
||||
float wmax = rec[0];
|
||||
for (size_t i = 1; i < rec.size(); ++i) {
|
||||
wmax = std::max(rec[i], wmax);
|
||||
template <typename Iterator>
|
||||
XGBOOST_DEVICE inline void Softmax(Iterator start, Iterator end) {
|
||||
static_assert(std::is_same<bst_float,
|
||||
typename std::remove_reference<
|
||||
decltype(std::declval<Iterator>().operator*())>::type
|
||||
>::value,
|
||||
"Values should be of type bst_float");
|
||||
bst_float wmax = *start;
|
||||
for (Iterator i = start+1; i != end; ++i) {
|
||||
wmax = fmaxf(*i, wmax);
|
||||
}
|
||||
double wsum = 0.0f;
|
||||
for (float & elem : rec) {
|
||||
elem = std::exp(elem - wmax);
|
||||
wsum += elem;
|
||||
for (Iterator i = start; i != end; ++i) {
|
||||
*i = expf(*i - wmax);
|
||||
wsum += *i;
|
||||
}
|
||||
for (float & elem : rec) {
|
||||
elem /= static_cast<float>(wsum);
|
||||
for (Iterator i = start; i != end; ++i) {
|
||||
*i /= static_cast<float>(wsum);
|
||||
}
|
||||
}
|
||||
|
||||
@ -56,7 +66,7 @@ inline void Softmax(std::vector<float>* p_rec) {
|
||||
* \tparam Iterator The type of the iterator.
|
||||
*/
|
||||
template<typename Iterator>
|
||||
inline Iterator FindMaxIndex(Iterator begin, Iterator end) {
|
||||
XGBOOST_DEVICE inline Iterator FindMaxIndex(Iterator begin, Iterator end) {
|
||||
Iterator maxit = begin;
|
||||
for (Iterator it = begin; it != end; ++it) {
|
||||
if (*it > *maxit) maxit = it;
|
||||
|
||||
@ -49,7 +49,7 @@
|
||||
*
|
||||
* https://github.com/Microsoft/GSL/pull/664
|
||||
*
|
||||
* FIXME: Group these MSVC workarounds into a manageable place.
|
||||
* TODO(trivialfis): Group these MSVC workarounds into a manageable place.
|
||||
*/
|
||||
#if defined(_MSC_VER) && _MSC_VER < 1910
|
||||
|
||||
@ -68,7 +68,7 @@ namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
// Usual logging facility is not available inside device code.
|
||||
// FIXME: Make dmlc check more generic.
|
||||
// TODO(trivialfis): Make dmlc check more generic.
|
||||
#define KERNEL_CHECK(cond) \
|
||||
do { \
|
||||
if (!(cond)) { \
|
||||
@ -104,11 +104,11 @@ constexpr detail::ptrdiff_t dynamic_extent = -1; // NOLINT
|
||||
|
||||
enum class byte : unsigned char {}; // NOLINT
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <class ElementType, detail::ptrdiff_t Extent = dynamic_extent>
|
||||
template <class ElementType, detail::ptrdiff_t Extent>
|
||||
class Span;
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename SpanType, bool IsConst>
|
||||
class SpanIterator {
|
||||
using ElementType = typename SpanType::element_type;
|
||||
|
||||
204
src/common/transform.h
Normal file
204
src/common/transform.h
Normal file
@ -0,0 +1,204 @@
|
||||
/*!
|
||||
* Copyright 2018 XGBoost contributors
|
||||
*/
|
||||
#ifndef XGBOOST_COMMON_TRANSFORM_H_
|
||||
#define XGBOOST_COMMON_TRANSFORM_H_
|
||||
|
||||
#include <dmlc/omp.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <vector>
|
||||
#include <type_traits> // enable_if
|
||||
|
||||
#include "host_device_vector.h"
|
||||
#include "common.h"
|
||||
#include "span.h"
|
||||
|
||||
#if defined (__CUDACC__)
|
||||
#include "device_helpers.cuh"
|
||||
#endif
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
constexpr size_t kBlockThreads = 256;
|
||||
|
||||
namespace detail {
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
template <typename Functor, typename... SpanType>
|
||||
__global__ void LaunchCUDAKernel(Functor _func, Range _range,
|
||||
SpanType... _spans) {
|
||||
for (auto i : dh::GridStrideRange(*_range.begin(), *_range.end())) {
|
||||
_func(i, _spans...);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/*! \brief Do Transformation on HostDeviceVectors.
|
||||
*
|
||||
* \tparam CompiledWithCuda A bool parameter used to distinguish compilation
|
||||
* trajectories, users do not need to use it.
|
||||
*
|
||||
* Note: Using Transform is a VERY tricky thing to do. Transform uses template
|
||||
* argument to duplicate itself into two different types, one for CPU,
|
||||
* another for CUDA. The trick is not without its flaw:
|
||||
*
|
||||
* If you use it in a function that can be compiled by both nvcc and host
|
||||
* compiler, the behaviour is un-defined! Because your function is NOT
|
||||
* duplicated by `CompiledWithCuda`. At link time, cuda compiler resolution
|
||||
* will merge functions with same signature.
|
||||
*/
|
||||
template <bool CompiledWithCuda = WITH_CUDA()>
|
||||
class Transform {
|
||||
private:
|
||||
template <typename Functor>
|
||||
struct Evaluator {
|
||||
public:
|
||||
Evaluator(Functor func, Range range, GPUSet devices, bool reshard) :
|
||||
func_(func), range_{std::move(range)},
|
||||
distribution_{std::move(GPUDistribution::Block(devices))},
|
||||
reshard_{reshard} {}
|
||||
Evaluator(Functor func, Range range, GPUDistribution dist,
|
||||
bool reshard) :
|
||||
func_(func), range_{std::move(range)}, distribution_{std::move(dist)},
|
||||
reshard_{reshard} {}
|
||||
|
||||
/*!
|
||||
* \brief Evaluate the functor with input pointers to HostDeviceVector.
|
||||
*
|
||||
* \tparam HDV... HostDeviceVectors type.
|
||||
* \param vectors Pointers to HostDeviceVector.
|
||||
*/
|
||||
template <typename... HDV>
|
||||
void Eval(HDV... vectors) const {
|
||||
bool on_device = !distribution_.IsEmpty();
|
||||
|
||||
if (on_device) {
|
||||
LaunchCUDA(func_, vectors...);
|
||||
} else {
|
||||
LaunchCPU(func_, vectors...);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// CUDA UnpackHDV
|
||||
template <typename T>
|
||||
Span<T> UnpackHDV(HostDeviceVector<T>* _vec, int _device) const {
|
||||
return _vec->DeviceSpan(_device);
|
||||
}
|
||||
template <typename T>
|
||||
Span<T const> UnpackHDV(const HostDeviceVector<T>* _vec, int _device) const {
|
||||
return _vec->ConstDeviceSpan(_device);
|
||||
}
|
||||
// CPU UnpackHDV
|
||||
template <typename T>
|
||||
Span<T> UnpackHDV(HostDeviceVector<T>* _vec) const {
|
||||
return Span<T> {_vec->HostPointer(),
|
||||
static_cast<typename Span<T>::index_type>(_vec->Size())};
|
||||
}
|
||||
template <typename T>
|
||||
Span<T const> UnpackHDV(const HostDeviceVector<T>* _vec) const {
|
||||
return Span<T const> {_vec->ConstHostPointer(),
|
||||
static_cast<typename Span<T>::index_type>(_vec->Size())};
|
||||
}
|
||||
// Recursive unpack for Reshard.
|
||||
template <typename T>
|
||||
void UnpackReshard(GPUDistribution dist, const HostDeviceVector<T>* vector) const {
|
||||
vector->Reshard(dist);
|
||||
}
|
||||
template <typename Head, typename... Rest>
|
||||
void UnpackReshard(GPUDistribution dist,
|
||||
const HostDeviceVector<Head>* _vector,
|
||||
const HostDeviceVector<Rest>*... _vectors) const {
|
||||
_vector->Reshard(dist);
|
||||
UnpackReshard(dist, _vectors...);
|
||||
}
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
template <typename std::enable_if<CompiledWithCuda>::type* = nullptr,
|
||||
typename... HDV>
|
||||
void LaunchCUDA(Functor _func, HDV*... _vectors) const {
|
||||
if (reshard_)
|
||||
UnpackReshard(distribution_, _vectors...);
|
||||
|
||||
GPUSet devices = distribution_.Devices();
|
||||
size_t range_size = *range_.end() - *range_.begin();
|
||||
#pragma omp parallel for schedule(static, 1) if (devices.Size() > 1)
|
||||
for (omp_ulong i = 0; i < devices.Size(); ++i) {
|
||||
int d = devices.Index(i);
|
||||
// Ignore other attributes of GPUDistribution for spliting index.
|
||||
size_t shard_size =
|
||||
GPUDistribution::Block(devices).ShardSize(range_size, d);
|
||||
Range shard_range {0, static_cast<Range::DifferenceType>(shard_size)};
|
||||
dh::safe_cuda(cudaSetDevice(d));
|
||||
const int GRID_SIZE =
|
||||
static_cast<int>(dh::DivRoundUp(*(range_.end()), kBlockThreads));
|
||||
|
||||
detail::LaunchCUDAKernel<<<GRID_SIZE, kBlockThreads>>>(
|
||||
_func, shard_range, UnpackHDV(_vectors, d)...);
|
||||
dh::safe_cuda(cudaGetLastError());
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
}
|
||||
}
|
||||
#else
|
||||
/*! \brief Dummy funtion defined when compiling for CPU. */
|
||||
template <typename std::enable_if<!CompiledWithCuda>::type* = nullptr,
|
||||
typename... HDV>
|
||||
void LaunchCUDA(Functor _func, HDV*... _vectors) const {
|
||||
LOG(FATAL) << "Not part of device code. WITH_CUDA: " << WITH_CUDA();
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename... HDV>
|
||||
void LaunchCPU(Functor func, HDV*... vectors) const {
|
||||
auto end = *(range_.end());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (omp_ulong idx = 0; idx < end; ++idx) {
|
||||
func(idx, UnpackHDV(vectors)...);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
/*! \brief Callable object. */
|
||||
Functor func_;
|
||||
/*! \brief Range object specifying parallel threads index range. */
|
||||
Range range_;
|
||||
/*! \brief Whether resharding for vectors is required. */
|
||||
bool reshard_;
|
||||
GPUDistribution distribution_;
|
||||
};
|
||||
|
||||
public:
|
||||
/*!
|
||||
* \brief Initialize a Transform object.
|
||||
*
|
||||
* \tparam Functor A callable object type.
|
||||
* \return A Evaluator having one method Eval.
|
||||
*
|
||||
* \param func A callable object, accepting a size_t thread index,
|
||||
* followed by a set of Span classes.
|
||||
* \param range Range object specifying parallel threads index range.
|
||||
* \param devices GPUSet specifying GPUs to use, when compiling for CPU,
|
||||
* this should be GPUSet::Empty().
|
||||
* \param reshard Whether Reshard for HostDeviceVector is needed.
|
||||
*/
|
||||
template <typename Functor>
|
||||
static Evaluator<Functor> Init(Functor func, Range const range,
|
||||
GPUSet const devices,
|
||||
bool const reshard = true) {
|
||||
return Evaluator<Functor> {func, std::move(range), std::move(devices), reshard};
|
||||
}
|
||||
template <typename Functor>
|
||||
static Evaluator<Functor> Init(Functor func, Range const range,
|
||||
GPUDistribution const dist,
|
||||
bool const reshard = true) {
|
||||
return Evaluator<Functor> {func, std::move(range), std::move(dist), reshard};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
#endif // XGBOOST_COMMON_TRANSFORM_H_
|
||||
@ -1,73 +1,18 @@
|
||||
/*!
|
||||
* Copyright 2018 by Contributors
|
||||
* \file hinge.cc
|
||||
* \brief Provides an implementation of the hinge loss function
|
||||
* \author Henry Gouk
|
||||
* Copyright 2018 XGBoost contributors
|
||||
*/
|
||||
#include <xgboost/objective.h>
|
||||
#include "../common/math.h"
|
||||
|
||||
// Dummy file to keep the CUDA conditional compile trick.
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(hinge);
|
||||
|
||||
class HingeObj : public ObjFunction {
|
||||
public:
|
||||
HingeObj() = default;
|
||||
|
||||
void Configure(
|
||||
const std::vector<std::pair<std::string, std::string> > &args) override {
|
||||
// This objective does not take any parameters
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float> &preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size())
|
||||
<< "labels are not correctly provided"
|
||||
<< "preds.size=" << preds.Size()
|
||||
<< ", label.size=" << info.labels_.Size();
|
||||
const auto& preds_h = preds.HostVector();
|
||||
const auto& labels_h = info.labels_.HostVector();
|
||||
const auto& weights_h = info.weights_.HostVector();
|
||||
|
||||
out_gpair->Resize(preds_h.size());
|
||||
auto& gpair = out_gpair->HostVector();
|
||||
|
||||
for (size_t i = 0; i < preds_h.size(); ++i) {
|
||||
auto y = labels_h[i] * 2.0 - 1.0;
|
||||
bst_float p = preds_h[i];
|
||||
bst_float w = weights_h.size() > 0 ? weights_h[i] : 1.0f;
|
||||
bst_float g, h;
|
||||
if (p * y < 1.0) {
|
||||
g = -y * w;
|
||||
h = w;
|
||||
} else {
|
||||
g = 0.0;
|
||||
h = std::numeric_limits<bst_float>::min();
|
||||
}
|
||||
gpair[i] = GradientPair(g, h);
|
||||
}
|
||||
}
|
||||
|
||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
|
||||
std::vector<bst_float> &preds = io_preds->HostVector();
|
||||
for (auto& p : preds) {
|
||||
p = p > 0.0 ? 1.0 : 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
const char* DefaultEvalMetric() const override {
|
||||
return "error";
|
||||
}
|
||||
};
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(HingeObj, "binary:hinge")
|
||||
.describe("Hinge loss. Expects labels to be in [0,1f]")
|
||||
.set_body([]() { return new HingeObj(); });
|
||||
DMLC_REGISTRY_FILE_TAG(hinge_obj);
|
||||
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
#include "hinge.cu"
|
||||
#endif
|
||||
|
||||
109
src/objective/hinge.cu
Normal file
109
src/objective/hinge.cu
Normal file
@ -0,0 +1,109 @@
|
||||
/*!
|
||||
* Copyright 2018 by Contributors
|
||||
* \file hinge.cc
|
||||
* \brief Provides an implementation of the hinge loss function
|
||||
* \author Henry Gouk
|
||||
*/
|
||||
#include <xgboost/objective.h>
|
||||
#include "../common/math.h"
|
||||
#include "../common/transform.h"
|
||||
#include "../common/common.h"
|
||||
#include "../common/span.h"
|
||||
#include "../common/host_device_vector.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
DMLC_REGISTRY_FILE_TAG(hinge_obj_gpu);
|
||||
#endif
|
||||
|
||||
struct HingeObjParam : public dmlc::Parameter<HingeObjParam> {
|
||||
int n_gpus;
|
||||
int gpu_id;
|
||||
DMLC_DECLARE_PARAMETER(HingeObjParam) {
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_default(0).set_lower_bound(0)
|
||||
.describe("Number of GPUs to use for multi-gpu algorithms.");
|
||||
DMLC_DECLARE_FIELD(gpu_id)
|
||||
.set_lower_bound(0)
|
||||
.set_default(0)
|
||||
.describe("gpu to use for objective function evaluation");
|
||||
}
|
||||
};
|
||||
|
||||
class HingeObj : public ObjFunction {
|
||||
public:
|
||||
HingeObj() = default;
|
||||
|
||||
void Configure(
|
||||
const std::vector<std::pair<std::string, std::string> > &args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
|
||||
label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size());
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float> &preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size())
|
||||
<< "labels are not correctly provided"
|
||||
<< "preds.size=" << preds.Size()
|
||||
<< ", label.size=" << info.labels_.Size();
|
||||
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
const size_t ndata = preds.Size();
|
||||
out_gpair->Resize(ndata);
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
common::Span<int> _label_correct,
|
||||
common::Span<GradientPair> _out_gpair,
|
||||
common::Span<const bst_float> _preds,
|
||||
common::Span<const bst_float> _labels,
|
||||
common::Span<const bst_float> _weights) {
|
||||
bst_float p = _preds[_idx];
|
||||
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
|
||||
bst_float y = _labels[_idx] * 2.0 - 1.0;
|
||||
bst_float g, h;
|
||||
if (p * y < 1.0) {
|
||||
g = -y * w;
|
||||
h = w;
|
||||
} else {
|
||||
g = 0.0;
|
||||
h = std::numeric_limits<bst_float>::min();
|
||||
}
|
||||
_out_gpair[_idx] = GradientPair(g, h);
|
||||
},
|
||||
common::Range{0, static_cast<int64_t>(ndata)}, devices_).Eval(
|
||||
&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
|
||||
}
|
||||
|
||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
|
||||
common::Transform<>::Init(
|
||||
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
|
||||
_preds[_idx] = _preds[_idx] > 0.0 ? 1.0 : 0.0;
|
||||
},
|
||||
common::Range{0, static_cast<int64_t>(io_preds->Size()), 1}, devices_)
|
||||
.Eval(io_preds);
|
||||
}
|
||||
|
||||
const char* DefaultEvalMetric() const override {
|
||||
return "error";
|
||||
}
|
||||
|
||||
private:
|
||||
GPUSet devices_;
|
||||
HostDeviceVector<int> label_correct_;
|
||||
HingeObjParam param_;
|
||||
};
|
||||
|
||||
// register the objective functions
|
||||
DMLC_REGISTER_PARAMETER(HingeObjParam);
|
||||
// register the objective functions
|
||||
XGBOOST_REGISTER_OBJECTIVE(HingeObj, "binary:hinge")
|
||||
.describe("Hinge loss. Expects labels to be in [0,1f]")
|
||||
.set_body([]() { return new HingeObj(); });
|
||||
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
@ -1,141 +1,18 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file multi_class.cc
|
||||
* \brief Definition of multi-class classification objectives.
|
||||
* \author Tianqi Chen
|
||||
* Copyright 2018 XGBoost contributors
|
||||
*/
|
||||
#include <dmlc/omp.h>
|
||||
#include <dmlc/parameter.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/objective.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "../common/math.h"
|
||||
|
||||
// Dummy file to keep the CUDA conditional compile trick.
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(multiclass_obj);
|
||||
|
||||
struct SoftmaxMultiClassParam : public dmlc::Parameter<SoftmaxMultiClassParam> {
|
||||
int num_class;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(SoftmaxMultiClassParam) {
|
||||
DMLC_DECLARE_FIELD(num_class).set_lower_bound(1)
|
||||
.describe("Number of output class in the multi-class classification.");
|
||||
}
|
||||
};
|
||||
|
||||
class SoftmaxMultiClassObj : public ObjFunction {
|
||||
public:
|
||||
explicit SoftmaxMultiClassObj(bool output_prob)
|
||||
: output_prob_(output_prob) {
|
||||
}
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
}
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo& info,
|
||||
int iter,
|
||||
HostDeviceVector<GradientPair>* out_gpair) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK(preds.Size() == (static_cast<size_t>(param_.num_class) * info.labels_.Size()))
|
||||
<< "SoftmaxMultiClassObj: label size and pred size does not match";
|
||||
const std::vector<bst_float>& preds_h = preds.HostVector();
|
||||
out_gpair->Resize(preds_h.size());
|
||||
std::vector<GradientPair>& gpair = out_gpair->HostVector();
|
||||
const int nclass = param_.num_class;
|
||||
const auto ndata = static_cast<omp_ulong>(preds_h.size() / nclass);
|
||||
|
||||
const auto& labels = info.labels_.HostVector();
|
||||
int label_error = 0;
|
||||
#pragma omp parallel
|
||||
{
|
||||
std::vector<bst_float> rec(nclass);
|
||||
#pragma omp for schedule(static)
|
||||
for (omp_ulong i = 0; i < ndata; ++i) {
|
||||
for (int k = 0; k < nclass; ++k) {
|
||||
rec[k] = preds_h[i * nclass + k];
|
||||
}
|
||||
common::Softmax(&rec);
|
||||
auto label = static_cast<int>(labels[i]);
|
||||
if (label < 0 || label >= nclass) {
|
||||
label_error = label; label = 0;
|
||||
}
|
||||
const bst_float wt = info.GetWeight(i);
|
||||
for (int k = 0; k < nclass; ++k) {
|
||||
bst_float p = rec[k];
|
||||
const float eps = 1e-16f;
|
||||
const bst_float h = fmax(2.0f * p * (1.0f - p) * wt, eps);
|
||||
if (label == k) {
|
||||
gpair[i * nclass + k] = GradientPair((p - 1.0f) * wt, h);
|
||||
} else {
|
||||
gpair[i * nclass + k] = GradientPair(p* wt, h);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
CHECK(label_error >= 0 && label_error < nclass)
|
||||
<< "SoftmaxMultiClassObj: label must be in [0, num_class),"
|
||||
<< " num_class=" << nclass
|
||||
<< " but found " << label_error << " in label.";
|
||||
}
|
||||
void PredTransform(HostDeviceVector<bst_float>* io_preds) override {
|
||||
this->Transform(io_preds, output_prob_);
|
||||
}
|
||||
void EvalTransform(HostDeviceVector<bst_float>* io_preds) override {
|
||||
this->Transform(io_preds, true);
|
||||
}
|
||||
const char* DefaultEvalMetric() const override {
|
||||
return "merror";
|
||||
}
|
||||
|
||||
private:
|
||||
inline void Transform(HostDeviceVector<bst_float> *io_preds, bool prob) {
|
||||
std::vector<bst_float> &preds = io_preds->HostVector();
|
||||
std::vector<bst_float> tmp;
|
||||
const int nclass = param_.num_class;
|
||||
const auto ndata = static_cast<omp_ulong>(preds.size() / nclass);
|
||||
if (!prob) tmp.resize(ndata);
|
||||
|
||||
#pragma omp parallel
|
||||
{
|
||||
std::vector<bst_float> rec(nclass);
|
||||
#pragma omp for schedule(static)
|
||||
for (omp_ulong j = 0; j < ndata; ++j) {
|
||||
for (int k = 0; k < nclass; ++k) {
|
||||
rec[k] = preds[j * nclass + k];
|
||||
}
|
||||
if (!prob) {
|
||||
tmp[j] = static_cast<bst_float>(
|
||||
common::FindMaxIndex(rec.begin(), rec.end()) - rec.begin());
|
||||
} else {
|
||||
common::Softmax(&rec);
|
||||
for (int k = 0; k < nclass; ++k) {
|
||||
preds[j * nclass + k] = rec[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!prob) preds = tmp;
|
||||
}
|
||||
// output probability
|
||||
bool output_prob_;
|
||||
// parameter
|
||||
SoftmaxMultiClassParam param_;
|
||||
};
|
||||
|
||||
// register the objective functions
|
||||
DMLC_REGISTER_PARAMETER(SoftmaxMultiClassParam);
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax")
|
||||
.describe("Softmax for multi-class classification, output class index.")
|
||||
.set_body([]() { return new SoftmaxMultiClassObj(false); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(SoftprobMultiClass, "multi:softprob")
|
||||
.describe("Softmax for multi-class classification, output probability distribution.")
|
||||
.set_body([]() { return new SoftmaxMultiClassObj(true); });
|
||||
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
#include "multiclass_obj.cu"
|
||||
#endif
|
||||
|
||||
195
src/objective/multiclass_obj.cu
Normal file
195
src/objective/multiclass_obj.cu
Normal file
@ -0,0 +1,195 @@
|
||||
/*!
|
||||
* Copyright 2015-2018 by Contributors
|
||||
* \file multi_class.cc
|
||||
* \brief Definition of multi-class classification objectives.
|
||||
* \author Tianqi Chen
|
||||
*/
|
||||
#include <dmlc/omp.h>
|
||||
#include <dmlc/parameter.h>
|
||||
#include <xgboost/data.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/objective.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
#include "../common/math.h"
|
||||
#include "../common/transform.h"
|
||||
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
DMLC_REGISTRY_FILE_TAG(multiclass_obj_gpu);
|
||||
#endif
|
||||
|
||||
struct SoftmaxMultiClassParam : public dmlc::Parameter<SoftmaxMultiClassParam> {
|
||||
int num_class;
|
||||
int n_gpus;
|
||||
int gpu_id;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(SoftmaxMultiClassParam) {
|
||||
DMLC_DECLARE_FIELD(num_class).set_lower_bound(1)
|
||||
.describe("Number of output class in the multi-class classification.");
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_default(-1).set_lower_bound(-1)
|
||||
.describe("Number of GPUs to use for multi-gpu algorithms.");
|
||||
DMLC_DECLARE_FIELD(gpu_id)
|
||||
.set_lower_bound(0)
|
||||
.set_default(0)
|
||||
.describe("gpu to use for objective function evaluation");
|
||||
}
|
||||
};
|
||||
// TODO(trivialfis): Currently the resharding in softmax is less than ideal
|
||||
// due to repeated copying data between CPU and GPUs. Maybe we just use single
|
||||
// GPU?
|
||||
class SoftmaxMultiClassObj : public ObjFunction {
|
||||
public:
|
||||
explicit SoftmaxMultiClassObj(bool output_prob)
|
||||
: output_prob_(output_prob) {
|
||||
}
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1
|
||||
devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
|
||||
label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size());
|
||||
}
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo& info,
|
||||
int iter,
|
||||
HostDeviceVector<GradientPair>* out_gpair) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK(preds.Size() == (static_cast<size_t>(param_.num_class) * info.labels_.Size()))
|
||||
<< "SoftmaxMultiClassObj: label size and pred size does not match";
|
||||
|
||||
const int nclass = param_.num_class;
|
||||
const auto ndata = static_cast<int64_t>(preds.Size() / nclass);
|
||||
|
||||
// clear out device memory;
|
||||
out_gpair->Reshard(GPUSet::Empty());
|
||||
preds.Reshard(GPUSet::Empty());
|
||||
|
||||
out_gpair->Reshard(GPUDistribution::Granular(devices_, nclass));
|
||||
info.labels_.Reshard(GPUDistribution::Block(devices_));
|
||||
info.weights_.Reshard(GPUDistribution::Block(devices_));
|
||||
preds.Reshard(GPUDistribution::Granular(devices_, nclass));
|
||||
label_correct_.Reshard(GPUDistribution::Block(devices_));
|
||||
|
||||
out_gpair->Resize(preds.Size());
|
||||
label_correct_.Fill(1);
|
||||
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t idx,
|
||||
common::Span<GradientPair> gpair,
|
||||
common::Span<bst_float const> labels,
|
||||
common::Span<bst_float const> preds,
|
||||
common::Span<bst_float const> weights,
|
||||
common::Span<int> _label_correct) {
|
||||
common::Span<bst_float const> point = preds.subspan(idx * nclass, nclass);
|
||||
|
||||
// Part of Softmax function
|
||||
bst_float wmax = std::numeric_limits<bst_float>::min();
|
||||
for (auto const i : point) { wmax = fmaxf(i, wmax); }
|
||||
double wsum = 0.0f;
|
||||
for (auto const i : point) { wsum += expf(i - wmax); }
|
||||
auto label = labels[idx];
|
||||
if (label < 0 || label >= nclass) {
|
||||
_label_correct[0] = 0;
|
||||
label = 0;
|
||||
}
|
||||
bst_float wt = is_null_weight ? 1.0f : weights[idx];
|
||||
for (int k = 0; k < nclass; ++k) {
|
||||
// Computation duplicated to avoid creating a cache.
|
||||
bst_float p = expf(point[k] - wmax) / static_cast<float>(wsum);
|
||||
const bst_float h = fmax(2.0f * p * (1.0f - p) * wt, kRtEps);
|
||||
p = label == k ? p - 1.0f : p;
|
||||
gpair[idx * nclass + k] = GradientPair(p * wt, h);
|
||||
}
|
||||
}, common::Range{0, ndata}, devices_, false)
|
||||
.Eval(out_gpair, &info.labels_, &preds, &info.weights_, &label_correct_);
|
||||
|
||||
out_gpair->Reshard(GPUSet::Empty());
|
||||
out_gpair->Reshard(GPUDistribution::Block(devices_));
|
||||
preds.Reshard(GPUSet::Empty());
|
||||
preds.Reshard(GPUDistribution::Block(devices_));
|
||||
|
||||
std::vector<int>& label_correct_h = label_correct_.HostVector();
|
||||
for (auto const flag : label_correct_h) {
|
||||
if (flag != 1) {
|
||||
LOG(FATAL) << "SoftmaxMultiClassObj: label must be in [0, num_class).";
|
||||
}
|
||||
}
|
||||
}
|
||||
void PredTransform(HostDeviceVector<bst_float>* io_preds) override {
|
||||
this->Transform(io_preds, output_prob_);
|
||||
}
|
||||
void EvalTransform(HostDeviceVector<bst_float>* io_preds) override {
|
||||
this->Transform(io_preds, true);
|
||||
}
|
||||
const char* DefaultEvalMetric() const override {
|
||||
return "merror";
|
||||
}
|
||||
|
||||
inline void Transform(HostDeviceVector<bst_float> *io_preds, bool prob) {
|
||||
const int nclass = param_.num_class;
|
||||
const auto ndata = static_cast<int64_t>(io_preds->Size() / nclass);
|
||||
max_preds_.Resize(ndata);
|
||||
|
||||
io_preds->Reshard(GPUSet::Empty()); // clear out device memory
|
||||
if (prob) {
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
|
||||
common::Span<bst_float> point =
|
||||
_preds.subspan(_idx * nclass, nclass);
|
||||
common::Softmax(point.begin(), point.end());
|
||||
},
|
||||
common::Range{0, ndata}, GPUDistribution::Granular(devices_, nclass))
|
||||
.Eval(io_preds);
|
||||
} else {
|
||||
io_preds->Reshard(GPUDistribution::Granular(devices_, nclass));
|
||||
max_preds_.Reshard(GPUDistribution::Block(devices_));
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
common::Span<const bst_float> _preds,
|
||||
common::Span<bst_float> _max_preds) {
|
||||
common::Span<const bst_float> point =
|
||||
_preds.subspan(_idx * nclass, nclass);
|
||||
_max_preds[_idx] =
|
||||
common::FindMaxIndex(point.cbegin(),
|
||||
point.cend()) - point.cbegin();
|
||||
},
|
||||
common::Range{0, ndata}, devices_, false)
|
||||
.Eval(io_preds, &max_preds_);
|
||||
}
|
||||
if (!prob) {
|
||||
io_preds->Resize(max_preds_.Size());
|
||||
io_preds->Copy(max_preds_);
|
||||
}
|
||||
io_preds->Reshard(GPUSet::Empty()); // clear out device memory
|
||||
io_preds->Reshard(GPUDistribution::Block(devices_));
|
||||
}
|
||||
|
||||
private:
|
||||
// output probability
|
||||
bool output_prob_;
|
||||
// parameter
|
||||
SoftmaxMultiClassParam param_;
|
||||
GPUSet devices_;
|
||||
// Cache for max_preds
|
||||
HostDeviceVector<bst_float> max_preds_;
|
||||
HostDeviceVector<int> label_correct_;
|
||||
};
|
||||
|
||||
// register the objective functions
|
||||
DMLC_REGISTER_PARAMETER(SoftmaxMultiClassParam);
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(SoftmaxMultiClass, "multi:softmax")
|
||||
.describe("Softmax for multi-class classification, output class index.")
|
||||
.set_body([]() { return new SoftmaxMultiClassObj(false); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(SoftprobMultiClass, "multi:softprob")
|
||||
.describe("Softmax for multi-class classification, output probability distribution.")
|
||||
.set_body([]() { return new SoftmaxMultiClassObj(true); });
|
||||
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
@ -30,12 +30,15 @@ ObjFunction* ObjFunction::Create(const std::string& name) {
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
// List of files that will be force linked in static links.
|
||||
DMLC_REGISTRY_LINK_TAG(regression_obj);
|
||||
#ifdef XGBOOST_USE_CUDA
|
||||
DMLC_REGISTRY_LINK_TAG(regression_obj_gpu);
|
||||
#endif
|
||||
DMLC_REGISTRY_LINK_TAG(regression_obj_gpu);
|
||||
DMLC_REGISTRY_LINK_TAG(hinge_obj_gpu);
|
||||
DMLC_REGISTRY_LINK_TAG(multiclass_obj_gpu);
|
||||
#else
|
||||
DMLC_REGISTRY_LINK_TAG(regression_obj);
|
||||
DMLC_REGISTRY_LINK_TAG(hinge_obj);
|
||||
DMLC_REGISTRY_LINK_TAG(multiclass_obj);
|
||||
#endif
|
||||
DMLC_REGISTRY_LINK_TAG(rank_obj);
|
||||
DMLC_REGISTRY_LINK_TAG(hinge);
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
|
||||
@ -1,426 +1,18 @@
|
||||
/*!
|
||||
* Copyright 2015 by Contributors
|
||||
* \file regression_obj.cc
|
||||
* \brief Definition of single-value regression and classification objectives.
|
||||
* \author Tianqi Chen, Kailong Chen
|
||||
* Copyright 2018 XGBoost contributors
|
||||
*/
|
||||
#include <dmlc/omp.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/objective.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
#include "../common/math.h"
|
||||
#include "../common/avx_helpers.h"
|
||||
#include "./regression_loss.h"
|
||||
|
||||
// Dummy file to keep the CUDA conditional compile trick.
|
||||
|
||||
#include <dmlc/registry.h>
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(regression_obj);
|
||||
|
||||
struct RegLossParam : public dmlc::Parameter<RegLossParam> {
|
||||
float scale_pos_weight;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(RegLossParam) {
|
||||
DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f)
|
||||
.describe("Scale the weight of positive examples by this factor");
|
||||
}
|
||||
};
|
||||
|
||||
// regression loss function
|
||||
template <typename Loss>
|
||||
class RegLossObj : public ObjFunction {
|
||||
public:
|
||||
RegLossObj() = default;
|
||||
|
||||
void Configure(
|
||||
const std::vector<std::pair<std::string, std::string> > &args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
}
|
||||
void GetGradient(const HostDeviceVector<bst_float> &preds, const MetaInfo &info,
|
||||
int iter, HostDeviceVector<GradientPair> *out_gpair) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size())
|
||||
<< "labels are not correctly provided"
|
||||
<< "preds.size=" << preds.Size()
|
||||
<< ", label.size=" << info.labels_.Size();
|
||||
const auto& preds_h = preds.HostVector();
|
||||
const auto& labels = info.labels_.HostVector();
|
||||
const auto& weights = info.weights_.HostVector();
|
||||
|
||||
this->LazyCheckLabels(labels);
|
||||
out_gpair->Resize(preds_h.size());
|
||||
auto& gpair = out_gpair->HostVector();
|
||||
const auto n = static_cast<omp_ulong>(preds_h.size());
|
||||
auto gpair_ptr = out_gpair->HostPointer();
|
||||
avx::Float8 scale(param_.scale_pos_weight);
|
||||
|
||||
const omp_ulong remainder = n % 8;
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (omp_ulong i = 0; i < n - remainder; i += 8) {
|
||||
avx::Float8 y(&labels[i]);
|
||||
avx::Float8 p = Loss::PredTransform(avx::Float8(&preds_h[i]));
|
||||
avx::Float8 w = weights.empty() ? avx::Float8(1.0f)
|
||||
: avx::Float8(&weights[i]);
|
||||
// Adjust weight
|
||||
w += y * (scale * w - w);
|
||||
avx::Float8 grad = Loss::FirstOrderGradient(p, y);
|
||||
avx::Float8 hess = Loss::SecondOrderGradient(p, y);
|
||||
avx::StoreGpair(gpair_ptr + i, grad * w, hess * w);
|
||||
}
|
||||
for (omp_ulong i = n - remainder; i < n; ++i) {
|
||||
auto y = labels[i];
|
||||
bst_float p = Loss::PredTransform(preds_h[i]);
|
||||
bst_float w = info.GetWeight(i);
|
||||
w += y * ((param_.scale_pos_weight * w) - w);
|
||||
gpair[i] = GradientPair(Loss::FirstOrderGradient(p, y) * w,
|
||||
Loss::SecondOrderGradient(p, y) * w);
|
||||
}
|
||||
}
|
||||
const char *DefaultEvalMetric() const override {
|
||||
return Loss::DefaultEvalMetric();
|
||||
}
|
||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
|
||||
std::vector<bst_float> &preds = io_preds->HostVector();
|
||||
const auto ndata = static_cast<bst_omp_uint>(preds.size());
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (bst_omp_uint j = 0; j < ndata; ++j) {
|
||||
preds[j] = Loss::PredTransform(preds[j]);
|
||||
}
|
||||
}
|
||||
bst_float ProbToMargin(bst_float base_score) const override {
|
||||
return Loss::ProbToMargin(base_score);
|
||||
}
|
||||
|
||||
protected:
|
||||
void LazyCheckLabels(const std::vector<float> &labels) {
|
||||
if (labels_checked_) return;
|
||||
for (auto &y : labels) {
|
||||
CHECK(Loss::CheckLabel(y)) << Loss::LabelErrorMsg();
|
||||
}
|
||||
labels_checked_ = true;
|
||||
}
|
||||
RegLossParam param_;
|
||||
bool labels_checked_{false};
|
||||
};
|
||||
|
||||
// register the objective functions
|
||||
DMLC_REGISTER_PARAMETER(RegLossParam);
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear")
|
||||
.describe("Linear regression.")
|
||||
.set_body([]() { return new RegLossObj<LinearSquareLoss>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(LogisticRegression, "reg:logistic")
|
||||
.describe("Logistic regression for probability regression task.")
|
||||
.set_body([]() { return new RegLossObj<LogisticRegression>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(LogisticClassification, "binary:logistic")
|
||||
.describe("Logistic regression for binary classification task.")
|
||||
.set_body([]() { return new RegLossObj<LogisticClassification>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, "binary:logitraw")
|
||||
.describe("Logistic regression for classification, output score before logistic transformation")
|
||||
.set_body([]() { return new RegLossObj<LogisticRaw>(); });
|
||||
|
||||
// declare parameter
|
||||
struct PoissonRegressionParam : public dmlc::Parameter<PoissonRegressionParam> {
|
||||
float max_delta_step;
|
||||
DMLC_DECLARE_PARAMETER(PoissonRegressionParam) {
|
||||
DMLC_DECLARE_FIELD(max_delta_step).set_lower_bound(0.0f).set_default(0.7f)
|
||||
.describe("Maximum delta step we allow each weight estimation to be." \
|
||||
" This parameter is required for possion regression.");
|
||||
}
|
||||
};
|
||||
|
||||
// poisson regression for count
|
||||
class PoissonRegression : public ObjFunction {
|
||||
public:
|
||||
// declare functions
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float> &preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
|
||||
const auto& preds_h = preds.HostVector();
|
||||
out_gpair->Resize(preds.Size());
|
||||
auto& gpair = out_gpair->HostVector();
|
||||
const auto& labels = info.labels_.HostVector();
|
||||
// check if label in range
|
||||
bool label_correct = true;
|
||||
// start calculating gradient
|
||||
const omp_ulong ndata = static_cast<omp_ulong>(preds_h.size()); // NOLINT(*)
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
|
||||
bst_float p = preds_h[i];
|
||||
bst_float w = info.GetWeight(i);
|
||||
bst_float y = labels[i];
|
||||
if (y >= 0.0f) {
|
||||
gpair[i] = GradientPair((std::exp(p) - y) * w,
|
||||
std::exp(p + param_.max_delta_step) * w);
|
||||
} else {
|
||||
label_correct = false;
|
||||
}
|
||||
}
|
||||
CHECK(label_correct) << "PoissonRegression: label must be nonnegative";
|
||||
}
|
||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
|
||||
std::vector<bst_float> &preds = io_preds->HostVector();
|
||||
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
|
||||
preds[j] = std::exp(preds[j]);
|
||||
}
|
||||
}
|
||||
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
|
||||
PredTransform(io_preds);
|
||||
}
|
||||
bst_float ProbToMargin(bst_float base_score) const override {
|
||||
return std::log(base_score);
|
||||
}
|
||||
const char* DefaultEvalMetric() const override {
|
||||
return "poisson-nloglik";
|
||||
}
|
||||
|
||||
private:
|
||||
PoissonRegressionParam param_;
|
||||
};
|
||||
|
||||
// register the objective functions
|
||||
DMLC_REGISTER_PARAMETER(PoissonRegressionParam);
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson")
|
||||
.describe("Possion regression for count data.")
|
||||
.set_body([]() { return new PoissonRegression(); });
|
||||
|
||||
// cox regression for survival data (negative values mean they are censored)
|
||||
class CoxRegression : public ObjFunction {
|
||||
public:
|
||||
// declare functions
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {}
|
||||
void GetGradient(const HostDeviceVector<bst_float> &preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
|
||||
const auto& preds_h = preds.HostVector();
|
||||
out_gpair->Resize(preds_h.size());
|
||||
auto& gpair = out_gpair->HostVector();
|
||||
const std::vector<size_t> &label_order = info.LabelAbsSort();
|
||||
|
||||
const omp_ulong ndata = static_cast<omp_ulong>(preds_h.size()); // NOLINT(*)
|
||||
|
||||
// pre-compute a sum
|
||||
double exp_p_sum = 0; // we use double because we might need the precision with large datasets
|
||||
for (omp_ulong i = 0; i < ndata; ++i) {
|
||||
exp_p_sum += std::exp(preds_h[label_order[i]]);
|
||||
}
|
||||
|
||||
// start calculating grad and hess
|
||||
const auto& labels = info.labels_.HostVector();
|
||||
double r_k = 0;
|
||||
double s_k = 0;
|
||||
double last_exp_p = 0.0;
|
||||
double last_abs_y = 0.0;
|
||||
double accumulated_sum = 0;
|
||||
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
|
||||
const size_t ind = label_order[i];
|
||||
const double p = preds_h[ind];
|
||||
const double exp_p = std::exp(p);
|
||||
const double w = info.GetWeight(ind);
|
||||
const double y = labels[ind];
|
||||
const double abs_y = std::abs(y);
|
||||
|
||||
// only update the denominator after we move forward in time (labels are sorted)
|
||||
// this is Breslow's method for ties
|
||||
accumulated_sum += last_exp_p;
|
||||
if (last_abs_y < abs_y) {
|
||||
exp_p_sum -= accumulated_sum;
|
||||
accumulated_sum = 0;
|
||||
} else {
|
||||
CHECK(last_abs_y <= abs_y) << "CoxRegression: labels must be in sorted order, " <<
|
||||
"MetaInfo::LabelArgsort failed!";
|
||||
}
|
||||
|
||||
if (y > 0) {
|
||||
r_k += 1.0/exp_p_sum;
|
||||
s_k += 1.0/(exp_p_sum*exp_p_sum);
|
||||
}
|
||||
|
||||
const double grad = exp_p*r_k - static_cast<bst_float>(y > 0);
|
||||
const double hess = exp_p*r_k - exp_p*exp_p * s_k;
|
||||
gpair.at(ind) = GradientPair(grad * w, hess * w);
|
||||
|
||||
last_abs_y = abs_y;
|
||||
last_exp_p = exp_p;
|
||||
}
|
||||
}
|
||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
|
||||
std::vector<bst_float> &preds = io_preds->HostVector();
|
||||
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
|
||||
preds[j] = std::exp(preds[j]);
|
||||
}
|
||||
}
|
||||
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
|
||||
PredTransform(io_preds);
|
||||
}
|
||||
bst_float ProbToMargin(bst_float base_score) const override {
|
||||
return std::log(base_score);
|
||||
}
|
||||
const char* DefaultEvalMetric() const override {
|
||||
return "cox-nloglik";
|
||||
}
|
||||
};
|
||||
|
||||
// register the objective function
|
||||
XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox")
|
||||
.describe("Cox regression for censored survival data (negative labels are considered censored).")
|
||||
.set_body([]() { return new CoxRegression(); });
|
||||
|
||||
// gamma regression
|
||||
class GammaRegression : public ObjFunction {
|
||||
public:
|
||||
// declare functions
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float> &preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
|
||||
const auto& preds_h = preds.HostVector();
|
||||
out_gpair->Resize(preds_h.size());
|
||||
auto& gpair = out_gpair->HostVector();
|
||||
const auto& labels = info.labels_.HostVector();
|
||||
// check if label in range
|
||||
bool label_correct = true;
|
||||
// start calculating gradient
|
||||
const omp_ulong ndata = static_cast<omp_ulong>(preds_h.size()); // NOLINT(*)
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
|
||||
bst_float p = preds_h[i];
|
||||
bst_float w = info.GetWeight(i);
|
||||
bst_float y = labels[i];
|
||||
if (y >= 0.0f) {
|
||||
gpair[i] = GradientPair((1 - y / std::exp(p)) * w, y / std::exp(p) * w);
|
||||
} else {
|
||||
label_correct = false;
|
||||
}
|
||||
}
|
||||
CHECK(label_correct) << "GammaRegression: label must be positive";
|
||||
}
|
||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
|
||||
std::vector<bst_float> &preds = io_preds->HostVector();
|
||||
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
|
||||
preds[j] = std::exp(preds[j]);
|
||||
}
|
||||
}
|
||||
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
|
||||
PredTransform(io_preds);
|
||||
}
|
||||
bst_float ProbToMargin(bst_float base_score) const override {
|
||||
return std::log(base_score);
|
||||
}
|
||||
const char* DefaultEvalMetric() const override {
|
||||
return "gamma-nloglik";
|
||||
}
|
||||
};
|
||||
|
||||
// register the objective functions
|
||||
XGBOOST_REGISTER_OBJECTIVE(GammaRegression, "reg:gamma")
|
||||
.describe("Gamma regression for severity data.")
|
||||
.set_body([]() { return new GammaRegression(); });
|
||||
|
||||
// declare parameter
|
||||
struct TweedieRegressionParam : public dmlc::Parameter<TweedieRegressionParam> {
|
||||
float tweedie_variance_power;
|
||||
DMLC_DECLARE_PARAMETER(TweedieRegressionParam) {
|
||||
DMLC_DECLARE_FIELD(tweedie_variance_power).set_range(1.0f, 2.0f).set_default(1.5f)
|
||||
.describe("Tweedie variance power. Must be between in range [1, 2).");
|
||||
}
|
||||
};
|
||||
|
||||
// tweedie regression
|
||||
class TweedieRegression : public ObjFunction {
|
||||
public:
|
||||
// declare functions
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float> &preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
|
||||
const auto& preds_h = preds.HostVector();
|
||||
out_gpair->Resize(preds.Size());
|
||||
auto& gpair = out_gpair->HostVector();
|
||||
const auto& labels = info.labels_.HostVector();
|
||||
// check if label in range
|
||||
bool label_correct = true;
|
||||
// start calculating gradient
|
||||
const omp_ulong ndata = static_cast<omp_ulong>(preds.Size()); // NOLINT(*)
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
|
||||
bst_float p = preds_h[i];
|
||||
bst_float w = info.GetWeight(i);
|
||||
bst_float y = labels[i];
|
||||
float rho = param_.tweedie_variance_power;
|
||||
if (y >= 0.0f) {
|
||||
bst_float grad = -y * std::exp((1 - rho) * p) + std::exp((2 - rho) * p);
|
||||
bst_float hess = -y * (1 - rho) * \
|
||||
std::exp((1 - rho) * p) + (2 - rho) * std::exp((2 - rho) * p);
|
||||
gpair[i] = GradientPair(grad * w, hess * w);
|
||||
} else {
|
||||
label_correct = false;
|
||||
}
|
||||
}
|
||||
CHECK(label_correct) << "TweedieRegression: label must be nonnegative";
|
||||
}
|
||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
|
||||
std::vector<bst_float> &preds = io_preds->HostVector();
|
||||
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
|
||||
preds[j] = std::exp(preds[j]);
|
||||
}
|
||||
}
|
||||
|
||||
bst_float ProbToMargin(bst_float base_score) const override {
|
||||
return std::log(base_score);
|
||||
}
|
||||
|
||||
const char* DefaultEvalMetric() const override {
|
||||
std::ostringstream os;
|
||||
os << "tweedie-nloglik@" << param_.tweedie_variance_power;
|
||||
std::string metric = os.str();
|
||||
return metric.c_str();
|
||||
}
|
||||
|
||||
private:
|
||||
TweedieRegressionParam param_;
|
||||
};
|
||||
|
||||
// register the objective functions
|
||||
DMLC_REGISTER_PARAMETER(TweedieRegressionParam);
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(TweedieRegression, "reg:tweedie")
|
||||
.describe("Tweedie regression for insurance data.")
|
||||
.set_body([]() { return new TweedieRegression(); });
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
|
||||
#ifndef XGBOOST_USE_CUDA
|
||||
#include "regression_obj.cu"
|
||||
#endif
|
||||
|
||||
560
src/objective/regression_obj.cu
Normal file
560
src/objective/regression_obj.cu
Normal file
@ -0,0 +1,560 @@
|
||||
/*!
|
||||
* Copyright 2015-2018 by Contributors
|
||||
* \file regression_obj.cu
|
||||
* \brief Definition of single-value regression and classification objectives.
|
||||
* \author Tianqi Chen, Kailong Chen
|
||||
*/
|
||||
|
||||
#include <dmlc/omp.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/objective.h>
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "../common/span.h"
|
||||
#include "../common/transform.h"
|
||||
#include "../common/common.h"
|
||||
#include "../common/host_device_vector.h"
|
||||
#include "./regression_loss.h"
|
||||
|
||||
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
|
||||
#if defined(XGBOOST_USE_CUDA)
|
||||
DMLC_REGISTRY_FILE_TAG(regression_obj_gpu);
|
||||
#endif
|
||||
|
||||
struct RegLossParam : public dmlc::Parameter<RegLossParam> {
|
||||
float scale_pos_weight;
|
||||
int n_gpus;
|
||||
int gpu_id;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(RegLossParam) {
|
||||
DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f)
|
||||
.describe("Scale the weight of positive examples by this factor");
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_default(-1).set_lower_bound(-1)
|
||||
.describe("Number of GPUs to use for multi-gpu algorithms.");
|
||||
DMLC_DECLARE_FIELD(gpu_id)
|
||||
.set_lower_bound(0)
|
||||
.set_default(0)
|
||||
.describe("gpu to use for objective function evaluation");
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Loss>
|
||||
class RegLossObj : public ObjFunction {
|
||||
protected:
|
||||
HostDeviceVector<int> label_correct_;
|
||||
|
||||
public:
|
||||
RegLossObj() = default;
|
||||
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1
|
||||
devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
|
||||
label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size());
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
HostDeviceVector<GradientPair>* out_gpair) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size())
|
||||
<< "labels are not correctly provided"
|
||||
<< "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size();
|
||||
size_t ndata = preds.Size();
|
||||
out_gpair->Resize(ndata);
|
||||
label_correct_.Fill(1);
|
||||
|
||||
bool is_null_weight = info.weights_.Size() == 0;
|
||||
auto scale_pos_weight = param_.scale_pos_weight;
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
common::Span<int> _label_correct,
|
||||
common::Span<GradientPair> _out_gpair,
|
||||
common::Span<const bst_float> _preds,
|
||||
common::Span<const bst_float> _labels,
|
||||
common::Span<const bst_float> _weights) {
|
||||
bst_float p = Loss::PredTransform(_preds[_idx]);
|
||||
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
|
||||
bst_float label = _labels[_idx];
|
||||
if (label == 1.0f) {
|
||||
w *= scale_pos_weight;
|
||||
}
|
||||
if (!Loss::CheckLabel(label)) {
|
||||
// If there is an incorrect label, the host code will know.
|
||||
_label_correct[0] = 0;
|
||||
}
|
||||
_out_gpair[_idx] = GradientPair(Loss::FirstOrderGradient(p, label) * w,
|
||||
Loss::SecondOrderGradient(p, label) * w);
|
||||
},
|
||||
common::Range{0, static_cast<int64_t>(ndata)}, devices_).Eval(
|
||||
&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
|
||||
|
||||
// copy "label correct" flags back to host
|
||||
std::vector<int>& label_correct_h = label_correct_.HostVector();
|
||||
for (auto const flag : label_correct_h) {
|
||||
if (flag == 0) {
|
||||
LOG(FATAL) << Loss::LabelErrorMsg();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
const char* DefaultEvalMetric() const override {
|
||||
return Loss::DefaultEvalMetric();
|
||||
}
|
||||
|
||||
void PredTransform(HostDeviceVector<float> *io_preds) override {
|
||||
common::Transform<>::Init(
|
||||
[] XGBOOST_DEVICE(size_t _idx, common::Span<float> _preds) {
|
||||
_preds[_idx] = Loss::PredTransform(_preds[_idx]);
|
||||
}, common::Range{0, static_cast<int64_t>(io_preds->Size())},
|
||||
devices_).Eval(io_preds);
|
||||
}
|
||||
|
||||
float ProbToMargin(float base_score) const override {
|
||||
return Loss::ProbToMargin(base_score);
|
||||
}
|
||||
|
||||
protected:
|
||||
RegLossParam param_;
|
||||
GPUSet devices_;
|
||||
};
|
||||
|
||||
// register the objective functions
|
||||
DMLC_REGISTER_PARAMETER(RegLossParam);
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(LinearRegression, "reg:linear")
|
||||
.describe("Linear regression.")
|
||||
.set_body([]() { return new RegLossObj<LinearSquareLoss>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(LogisticRegression, "reg:logistic")
|
||||
.describe("Logistic regression for probability regression task.")
|
||||
.set_body([]() { return new RegLossObj<LogisticRegression>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(LogisticClassification, "binary:logistic")
|
||||
.describe("Logistic regression for binary classification task.")
|
||||
.set_body([]() { return new RegLossObj<LogisticClassification>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, "binary:logitraw")
|
||||
.describe("Logistic regression for classification, output score "
|
||||
"before logistic transformation.")
|
||||
.set_body([]() { return new RegLossObj<LogisticRaw>(); });
|
||||
|
||||
// Deprecated GPU functions
|
||||
XGBOOST_REGISTER_OBJECTIVE(GPULinearRegression, "gpu:reg:linear")
|
||||
.describe("Deprecated. Linear regression (computed on GPU).")
|
||||
.set_body([]() {
|
||||
LOG(WARNING) << "gpu:reg:linear is now deprecated, use reg:linear instead.";
|
||||
return new RegLossObj<LinearSquareLoss>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(GPULogisticRegression, "gpu:reg:logistic")
|
||||
.describe("Deprecated. Logistic regression for probability regression task (computed on GPU).")
|
||||
.set_body([]() {
|
||||
LOG(WARNING) << "gpu:reg:logistic is now deprecated, use reg:logistic instead.";
|
||||
return new RegLossObj<LogisticRegression>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(GPULogisticClassification, "gpu:binary:logistic")
|
||||
.describe("Deprecated. Logistic regression for binary classification task (computed on GPU).")
|
||||
.set_body([]() {
|
||||
LOG(WARNING) << "gpu:binary:logistic is now deprecated, use binary:logistic instead.";
|
||||
return new RegLossObj<LogisticClassification>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(GPULogisticRaw, "gpu:binary:logitraw")
|
||||
.describe("Deprecated. Logistic regression for classification, output score "
|
||||
"before logistic transformation (computed on GPU)")
|
||||
.set_body([]() {
|
||||
LOG(WARNING) << "gpu:binary:logitraw is now deprecated, use binary:logitraw instead.";
|
||||
return new RegLossObj<LogisticRaw>(); });
|
||||
// End deprecated
|
||||
|
||||
// declare parameter
|
||||
struct PoissonRegressionParam : public dmlc::Parameter<PoissonRegressionParam> {
|
||||
float max_delta_step;
|
||||
int n_gpus;
|
||||
int gpu_id;
|
||||
DMLC_DECLARE_PARAMETER(PoissonRegressionParam) {
|
||||
DMLC_DECLARE_FIELD(max_delta_step).set_lower_bound(0.0f).set_default(0.7f)
|
||||
.describe("Maximum delta step we allow each weight estimation to be." \
|
||||
" This parameter is required for possion regression.");
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_default(-1).set_lower_bound(-1)
|
||||
.describe("Number of GPUs to use for multi-gpu algorithms.");
|
||||
DMLC_DECLARE_FIELD(gpu_id)
|
||||
.set_lower_bound(0)
|
||||
.set_default(0)
|
||||
.describe("gpu to use for objective function evaluation");
|
||||
}
|
||||
};
|
||||
|
||||
// poisson regression for count
|
||||
class PoissonRegression : public ObjFunction {
|
||||
public:
|
||||
// declare functions
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1
|
||||
devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
|
||||
label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size());
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
|
||||
size_t ndata = preds.Size();
|
||||
out_gpair->Resize(ndata);
|
||||
label_correct_.Fill(1);
|
||||
|
||||
bool is_null_weight = info.weights_.Size() == 0;
|
||||
bst_float max_delta_step = param_.max_delta_step;
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
common::Span<int> _label_correct,
|
||||
common::Span<GradientPair> _out_gpair,
|
||||
common::Span<const bst_float> _preds,
|
||||
common::Span<const bst_float> _labels,
|
||||
common::Span<const bst_float> _weights) {
|
||||
bst_float p = _preds[_idx];
|
||||
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
|
||||
bst_float y = _labels[_idx];
|
||||
if (y < 0.0f) {
|
||||
_label_correct[0] = 0;
|
||||
}
|
||||
_out_gpair[_idx] = GradientPair{(expf(p) - y) * w,
|
||||
expf(p + max_delta_step) * w};
|
||||
},
|
||||
common::Range{0, static_cast<int64_t>(ndata)}, devices_).Eval(
|
||||
&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
|
||||
// copy "label correct" flags back to host
|
||||
std::vector<int>& label_correct_h = label_correct_.HostVector();
|
||||
for (auto const flag : label_correct_h) {
|
||||
if (flag == 0) {
|
||||
LOG(FATAL) << "PoissonRegression: label must be nonnegative";
|
||||
}
|
||||
}
|
||||
}
|
||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
|
||||
common::Transform<>::Init(
|
||||
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
|
||||
_preds[_idx] = expf(_preds[_idx]);
|
||||
},
|
||||
common::Range{0, static_cast<int64_t>(io_preds->Size())}, devices_)
|
||||
.Eval(io_preds);
|
||||
}
|
||||
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
|
||||
PredTransform(io_preds);
|
||||
}
|
||||
bst_float ProbToMargin(bst_float base_score) const override {
|
||||
return std::log(base_score);
|
||||
}
|
||||
const char* DefaultEvalMetric() const override {
|
||||
return "poisson-nloglik";
|
||||
}
|
||||
|
||||
private:
|
||||
GPUSet devices_;
|
||||
PoissonRegressionParam param_;
|
||||
HostDeviceVector<int> label_correct_;
|
||||
};
|
||||
|
||||
// register the objective functions
|
||||
DMLC_REGISTER_PARAMETER(PoissonRegressionParam);
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(PoissonRegression, "count:poisson")
|
||||
.describe("Possion regression for count data.")
|
||||
.set_body([]() { return new PoissonRegression(); });
|
||||
|
||||
|
||||
// cox regression for survival data (negative values mean they are censored)
|
||||
class CoxRegression : public ObjFunction {
|
||||
public:
|
||||
// declare functions
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {}
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
|
||||
const auto& preds_h = preds.HostVector();
|
||||
out_gpair->Resize(preds_h.size());
|
||||
auto& gpair = out_gpair->HostVector();
|
||||
const std::vector<size_t> &label_order = info.LabelAbsSort();
|
||||
|
||||
const omp_ulong ndata = static_cast<omp_ulong>(preds_h.size()); // NOLINT(*)
|
||||
|
||||
// pre-compute a sum
|
||||
double exp_p_sum = 0; // we use double because we might need the precision with large datasets
|
||||
for (omp_ulong i = 0; i < ndata; ++i) {
|
||||
exp_p_sum += std::exp(preds_h[label_order[i]]);
|
||||
}
|
||||
|
||||
// start calculating grad and hess
|
||||
const auto& labels = info.labels_.HostVector();
|
||||
double r_k = 0;
|
||||
double s_k = 0;
|
||||
double last_exp_p = 0.0;
|
||||
double last_abs_y = 0.0;
|
||||
double accumulated_sum = 0;
|
||||
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
|
||||
const size_t ind = label_order[i];
|
||||
const double p = preds_h[ind];
|
||||
const double exp_p = std::exp(p);
|
||||
const double w = info.GetWeight(ind);
|
||||
const double y = labels[ind];
|
||||
const double abs_y = std::abs(y);
|
||||
|
||||
// only update the denominator after we move forward in time (labels are sorted)
|
||||
// this is Breslow's method for ties
|
||||
accumulated_sum += last_exp_p;
|
||||
if (last_abs_y < abs_y) {
|
||||
exp_p_sum -= accumulated_sum;
|
||||
accumulated_sum = 0;
|
||||
} else {
|
||||
CHECK(last_abs_y <= abs_y) << "CoxRegression: labels must be in sorted order, " <<
|
||||
"MetaInfo::LabelArgsort failed!";
|
||||
}
|
||||
|
||||
if (y > 0) {
|
||||
r_k += 1.0/exp_p_sum;
|
||||
s_k += 1.0/(exp_p_sum*exp_p_sum);
|
||||
}
|
||||
|
||||
const double grad = exp_p*r_k - static_cast<bst_float>(y > 0);
|
||||
const double hess = exp_p*r_k - exp_p*exp_p * s_k;
|
||||
gpair.at(ind) = GradientPair(grad * w, hess * w);
|
||||
|
||||
last_abs_y = abs_y;
|
||||
last_exp_p = exp_p;
|
||||
}
|
||||
}
|
||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
|
||||
std::vector<bst_float> &preds = io_preds->HostVector();
|
||||
const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
|
||||
#pragma omp parallel for schedule(static)
|
||||
for (long j = 0; j < ndata; ++j) { // NOLINT(*)
|
||||
preds[j] = std::exp(preds[j]);
|
||||
}
|
||||
}
|
||||
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
|
||||
PredTransform(io_preds);
|
||||
}
|
||||
bst_float ProbToMargin(bst_float base_score) const override {
|
||||
return std::log(base_score);
|
||||
}
|
||||
const char* DefaultEvalMetric() const override {
|
||||
return "cox-nloglik";
|
||||
}
|
||||
};
|
||||
|
||||
// register the objective function
|
||||
XGBOOST_REGISTER_OBJECTIVE(CoxRegression, "survival:cox")
|
||||
.describe("Cox regression for censored survival data (negative labels are considered censored).")
|
||||
.set_body([]() { return new CoxRegression(); });
|
||||
|
||||
|
||||
struct GammaRegressionParam : public dmlc::Parameter<GammaRegressionParam> {
|
||||
int n_gpus;
|
||||
int gpu_id;
|
||||
DMLC_DECLARE_PARAMETER(GammaRegressionParam) {
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_default(-1).set_lower_bound(-1)
|
||||
.describe("Number of GPUs to use for multi-gpu algorithms.");
|
||||
DMLC_DECLARE_FIELD(gpu_id)
|
||||
.set_lower_bound(0)
|
||||
.set_default(0)
|
||||
.describe("gpu to use for objective function evaluation");
|
||||
}
|
||||
};
|
||||
|
||||
// gamma regression
|
||||
class GammaRegression : public ObjFunction {
|
||||
public:
|
||||
// declare functions
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1
|
||||
devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
|
||||
label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size());
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float> &preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
|
||||
const size_t ndata = preds.Size();
|
||||
out_gpair->Resize(ndata);
|
||||
label_correct_.Fill(1);
|
||||
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
common::Span<int> _label_correct,
|
||||
common::Span<GradientPair> _out_gpair,
|
||||
common::Span<const bst_float> _preds,
|
||||
common::Span<const bst_float> _labels,
|
||||
common::Span<const bst_float> _weights) {
|
||||
bst_float p = _preds[_idx];
|
||||
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
|
||||
bst_float y = _labels[_idx];
|
||||
if (y < 0.0f) {
|
||||
_label_correct[0] = 0;
|
||||
}
|
||||
_out_gpair[_idx] = GradientPair((1 - y / expf(p)) * w, y / expf(p) * w);
|
||||
},
|
||||
common::Range{0, static_cast<int64_t>(ndata)}, devices_).Eval(
|
||||
&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
|
||||
|
||||
// copy "label correct" flags back to host
|
||||
std::vector<int>& label_correct_h = label_correct_.HostVector();
|
||||
for (auto const flag : label_correct_h) {
|
||||
if (flag == 0) {
|
||||
LOG(FATAL) << "GammaRegression: label must be nonnegative";
|
||||
}
|
||||
}
|
||||
}
|
||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
|
||||
common::Transform<>::Init(
|
||||
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
|
||||
_preds[_idx] = expf(_preds[_idx]);
|
||||
},
|
||||
common::Range{0, static_cast<int64_t>(io_preds->Size())}, devices_)
|
||||
.Eval(io_preds);
|
||||
}
|
||||
void EvalTransform(HostDeviceVector<bst_float> *io_preds) override {
|
||||
PredTransform(io_preds);
|
||||
}
|
||||
bst_float ProbToMargin(bst_float base_score) const override {
|
||||
return std::log(base_score);
|
||||
}
|
||||
const char* DefaultEvalMetric() const override {
|
||||
return "gamma-nloglik";
|
||||
}
|
||||
|
||||
private:
|
||||
GPUSet devices_;
|
||||
GammaRegressionParam param_;
|
||||
HostDeviceVector<int> label_correct_;
|
||||
};
|
||||
|
||||
// register the objective functions
|
||||
DMLC_REGISTER_PARAMETER(GammaRegressionParam);
|
||||
// register the objective functions
|
||||
XGBOOST_REGISTER_OBJECTIVE(GammaRegression, "reg:gamma")
|
||||
.describe("Gamma regression for severity data.")
|
||||
.set_body([]() { return new GammaRegression(); });
|
||||
|
||||
|
||||
// declare parameter
|
||||
struct TweedieRegressionParam : public dmlc::Parameter<TweedieRegressionParam> {
|
||||
float tweedie_variance_power;
|
||||
int n_gpus;
|
||||
int gpu_id;
|
||||
DMLC_DECLARE_PARAMETER(TweedieRegressionParam) {
|
||||
DMLC_DECLARE_FIELD(tweedie_variance_power).set_range(1.0f, 2.0f).set_default(1.5f)
|
||||
.describe("Tweedie variance power. Must be between in range [1, 2).");
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_default(-1).set_lower_bound(-1)
|
||||
.describe("Number of GPUs to use for multi-gpu algorithms.");
|
||||
DMLC_DECLARE_FIELD(gpu_id)
|
||||
.set_lower_bound(0)
|
||||
.set_default(0)
|
||||
.describe("gpu to use for objective function evaluation");
|
||||
}
|
||||
};
|
||||
|
||||
// tweedie regression
|
||||
class TweedieRegression : public ObjFunction {
|
||||
public:
|
||||
// declare functions
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
CHECK(param_.n_gpus != 0) << "Must have at least one device"; // Default is -1
|
||||
devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
|
||||
label_correct_.Resize(devices_.IsEmpty() ? 1 : devices_.Size());
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<bst_float>& preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
HostDeviceVector<GradientPair> *out_gpair) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size()) << "labels are not correctly provided";
|
||||
const size_t ndata = preds.Size();
|
||||
out_gpair->Resize(ndata);
|
||||
label_correct_.Fill(1);
|
||||
|
||||
const bool is_null_weight = info.weights_.Size() == 0;
|
||||
const float rho = param_.tweedie_variance_power;
|
||||
common::Transform<>::Init(
|
||||
[=] XGBOOST_DEVICE(size_t _idx,
|
||||
common::Span<int> _label_correct,
|
||||
common::Span<GradientPair> _out_gpair,
|
||||
common::Span<const bst_float> _preds,
|
||||
common::Span<const bst_float> _labels,
|
||||
common::Span<const bst_float> _weights) {
|
||||
bst_float p = _preds[_idx];
|
||||
bst_float w = is_null_weight ? 1.0f : _weights[_idx];
|
||||
bst_float y = _labels[_idx];
|
||||
if (y < 0.0f) {
|
||||
_label_correct[0] = 0;
|
||||
}
|
||||
bst_float grad = -y * expf((1 - rho) * p) + expf((2 - rho) * p);
|
||||
bst_float hess =
|
||||
-y * (1 - rho) * \
|
||||
std::exp((1 - rho) * p) + (2 - rho) * expf((2 - rho) * p);
|
||||
_out_gpair[_idx] = GradientPair(grad * w, hess * w);
|
||||
},
|
||||
common::Range{0, static_cast<int64_t>(ndata), 1}, devices_)
|
||||
.Eval(&label_correct_, out_gpair, &preds, &info.labels_, &info.weights_);
|
||||
|
||||
// copy "label correct" flags back to host
|
||||
std::vector<int>& label_correct_h = label_correct_.HostVector();
|
||||
for (auto const flag : label_correct_h) {
|
||||
if (flag == 0) {
|
||||
LOG(FATAL) << "TweedieRegression: label must be nonnegative";
|
||||
}
|
||||
}
|
||||
}
|
||||
void PredTransform(HostDeviceVector<bst_float> *io_preds) override {
|
||||
common::Transform<>::Init(
|
||||
[] XGBOOST_DEVICE(size_t _idx, common::Span<bst_float> _preds) {
|
||||
_preds[_idx] = expf(_preds[_idx]);
|
||||
},
|
||||
common::Range{0, static_cast<int64_t>(io_preds->Size())}, devices_)
|
||||
.Eval(io_preds);
|
||||
}
|
||||
|
||||
bst_float ProbToMargin(bst_float base_score) const override {
|
||||
return std::log(base_score);
|
||||
}
|
||||
|
||||
const char* DefaultEvalMetric() const override {
|
||||
std::ostringstream os;
|
||||
os << "tweedie-nloglik@" << param_.tweedie_variance_power;
|
||||
std::string metric = os.str();
|
||||
return metric.c_str();
|
||||
}
|
||||
|
||||
private:
|
||||
GPUSet devices_;
|
||||
TweedieRegressionParam param_;
|
||||
HostDeviceVector<int> label_correct_;
|
||||
};
|
||||
|
||||
// register the objective functions
|
||||
DMLC_REGISTER_PARAMETER(TweedieRegressionParam);
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(TweedieRegression, "reg:tweedie")
|
||||
.describe("Tweedie regression for insurance data.")
|
||||
.set_body([]() { return new TweedieRegression(); });
|
||||
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
@ -1,202 +0,0 @@
|
||||
/*!
|
||||
* Copyright 2017 XGBoost contributors
|
||||
*/
|
||||
// GPU implementation of objective function.
|
||||
// Necessary to avoid extra copying of data to CPU.
|
||||
#include <dmlc/omp.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/copy.h>
|
||||
#include <xgboost/logging.h>
|
||||
#include <xgboost/objective.h>
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "../common/span.h"
|
||||
#include "../common/device_helpers.cuh"
|
||||
#include "../common/host_device_vector.h"
|
||||
#include "./regression_loss.h"
|
||||
|
||||
|
||||
namespace xgboost {
|
||||
namespace obj {
|
||||
|
||||
using dh::DVec;
|
||||
|
||||
DMLC_REGISTRY_FILE_TAG(regression_obj_gpu);
|
||||
|
||||
struct GPURegLossParam : public dmlc::Parameter<GPURegLossParam> {
|
||||
float scale_pos_weight;
|
||||
int n_gpus;
|
||||
int gpu_id;
|
||||
// declare parameters
|
||||
DMLC_DECLARE_PARAMETER(GPURegLossParam) {
|
||||
DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f)
|
||||
.describe("Scale the weight of positive examples by this factor");
|
||||
DMLC_DECLARE_FIELD(n_gpus).set_default(1).set_lower_bound(-1)
|
||||
.describe("Number of GPUs to use for multi-gpu algorithms (NOT IMPLEMENTED)");
|
||||
DMLC_DECLARE_FIELD(gpu_id)
|
||||
.set_lower_bound(0)
|
||||
.set_default(0)
|
||||
.describe("gpu to use for objective function evaluation");
|
||||
}
|
||||
};
|
||||
|
||||
// GPU kernel for gradient computation
|
||||
template<typename Loss>
|
||||
__global__ void get_gradient_k
|
||||
(common::Span<GradientPair> out_gpair, common::Span<int> label_correct,
|
||||
common::Span<const float> preds, common::Span<const float> labels,
|
||||
const float * __restrict__ weights, int n, float scale_pos_weight) {
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (i >= n)
|
||||
return;
|
||||
float p = Loss::PredTransform(preds[i]);
|
||||
float w = weights == nullptr ? 1.0f : weights[i];
|
||||
float label = labels[i];
|
||||
if (label == 1.0f)
|
||||
w *= scale_pos_weight;
|
||||
if (!Loss::CheckLabel(label))
|
||||
atomicAnd(label_correct.data(), 0);
|
||||
out_gpair[i] = GradientPair
|
||||
(Loss::FirstOrderGradient(p, label) * w, Loss::SecondOrderGradient(p, label) * w);
|
||||
}
|
||||
|
||||
// GPU kernel for predicate transformation
|
||||
template<typename Loss>
|
||||
__global__ void pred_transform_k(common::Span<float> preds, int n) {
|
||||
int i = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
if (i >= n)
|
||||
return;
|
||||
preds[i] = Loss::PredTransform(preds[i]);
|
||||
}
|
||||
|
||||
// regression loss function for evaluation on GPU (eventually)
|
||||
template<typename Loss>
|
||||
class GPURegLossObj : public ObjFunction {
|
||||
protected:
|
||||
HostDeviceVector<int> label_correct_;
|
||||
|
||||
// allocate device data for n elements, do nothing if memory is allocated already
|
||||
void LazyResize() {
|
||||
}
|
||||
|
||||
public:
|
||||
GPURegLossObj() {}
|
||||
|
||||
void Configure(const std::vector<std::pair<std::string, std::string> >& args) override {
|
||||
param_.InitAllowUnknown(args);
|
||||
CHECK(param_.n_gpus != 0) << "Must have at least one device";
|
||||
devices_ = GPUSet::All(param_.n_gpus).Normalised(param_.gpu_id);
|
||||
label_correct_.Reshard(devices_);
|
||||
label_correct_.Resize(devices_.Size());
|
||||
}
|
||||
|
||||
void GetGradient(const HostDeviceVector<float> &preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
HostDeviceVector<GradientPair>* out_gpair) override {
|
||||
CHECK_NE(info.labels_.Size(), 0U) << "label set cannot be empty";
|
||||
CHECK_EQ(preds.Size(), info.labels_.Size())
|
||||
<< "labels are not correctly provided"
|
||||
<< "preds.size=" << preds.Size() << ", label.size=" << info.labels_.Size();
|
||||
size_t ndata = preds.Size();
|
||||
preds.Reshard(devices_);
|
||||
info.labels_.Reshard(devices_);
|
||||
info.weights_.Reshard(devices_);
|
||||
out_gpair->Reshard(devices_);
|
||||
out_gpair->Resize(ndata);
|
||||
GetGradientDevice(preds, info, iter, out_gpair);
|
||||
}
|
||||
|
||||
private:
|
||||
void GetGradientDevice(const HostDeviceVector<float>& preds,
|
||||
const MetaInfo &info,
|
||||
int iter,
|
||||
HostDeviceVector<GradientPair>* out_gpair) {
|
||||
label_correct_.Fill(1);
|
||||
|
||||
// run the kernel
|
||||
#pragma omp parallel for schedule(static, 1) if (devices_.Size() > 1)
|
||||
for (int i = 0; i < devices_.Size(); ++i) {
|
||||
int d = devices_[i];
|
||||
dh::safe_cuda(cudaSetDevice(d));
|
||||
const int block = 256;
|
||||
size_t n = preds.DeviceSize(d);
|
||||
if (n > 0) {
|
||||
get_gradient_k<Loss><<<dh::DivRoundUp(n, block), block>>>
|
||||
(out_gpair->DeviceSpan(d), label_correct_.DeviceSpan(d),
|
||||
preds.DeviceSpan(d), info.labels_.DeviceSpan(d),
|
||||
info.weights_.Size() > 0 ? info.weights_.DevicePointer(d) : nullptr,
|
||||
n, param_.scale_pos_weight);
|
||||
dh::safe_cuda(cudaGetLastError());
|
||||
}
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
}
|
||||
|
||||
// copy "label correct" flags back to host
|
||||
std::vector<int>& label_correct_h = label_correct_.HostVector();
|
||||
for (int i = 0; i < devices_.Size(); ++i) {
|
||||
if (label_correct_h[i] == 0)
|
||||
LOG(FATAL) << Loss::LabelErrorMsg();
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
const char* DefaultEvalMetric() const override {
|
||||
return Loss::DefaultEvalMetric();
|
||||
}
|
||||
|
||||
void PredTransform(HostDeviceVector<float> *io_preds) override {
|
||||
io_preds->Reshard(devices_);
|
||||
size_t ndata = io_preds->Size();
|
||||
PredTransformDevice(io_preds);
|
||||
}
|
||||
|
||||
void PredTransformDevice(HostDeviceVector<float>* preds) {
|
||||
#pragma omp parallel for schedule(static, 1) if (devices_.Size() > 1)
|
||||
for (int i = 0; i < devices_.Size(); ++i) {
|
||||
int d = devices_[i];
|
||||
dh::safe_cuda(cudaSetDevice(d));
|
||||
const int block = 256;
|
||||
size_t n = preds->DeviceSize(d);
|
||||
if (n > 0) {
|
||||
pred_transform_k<Loss><<<dh::DivRoundUp(n, block), block>>>(
|
||||
preds->DeviceSpan(d), n);
|
||||
dh::safe_cuda(cudaGetLastError());
|
||||
}
|
||||
dh::safe_cuda(cudaDeviceSynchronize());
|
||||
}
|
||||
}
|
||||
|
||||
float ProbToMargin(float base_score) const override {
|
||||
return Loss::ProbToMargin(base_score);
|
||||
}
|
||||
|
||||
protected:
|
||||
GPURegLossParam param_;
|
||||
GPUSet devices_;
|
||||
};
|
||||
|
||||
// register the objective functions
|
||||
DMLC_REGISTER_PARAMETER(GPURegLossParam);
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(GPULinearRegression, "gpu:reg:linear")
|
||||
.describe("Linear regression (computed on GPU).")
|
||||
.set_body([]() { return new GPURegLossObj<LinearSquareLoss>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(GPULogisticRegression, "gpu:reg:logistic")
|
||||
.describe("Logistic regression for probability regression task (computed on GPU).")
|
||||
.set_body([]() { return new GPURegLossObj<LogisticRegression>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(GPULogisticClassification, "gpu:binary:logistic")
|
||||
.describe("Logistic regression for binary classification task (computed on GPU).")
|
||||
.set_body([]() { return new GPURegLossObj<LogisticClassification>(); });
|
||||
|
||||
XGBOOST_REGISTER_OBJECTIVE(GPULogisticRaw, "gpu:binary:logitraw")
|
||||
.describe("Logistic regression for classification, output score "
|
||||
"before logistic transformation (computed on GPU)")
|
||||
.set_body([]() { return new GPURegLossObj<LogisticRaw>(); });
|
||||
|
||||
} // namespace obj
|
||||
} // namespace xgboost
|
||||
@ -14,7 +14,7 @@ struct WriteSymbolFunction {
|
||||
WriteSymbolFunction(CompressedBufferWriter cbw, unsigned char* buffer_data_d,
|
||||
int* input_data_d)
|
||||
: cbw(cbw), buffer_data_d(buffer_data_d), input_data_d(input_data_d) {}
|
||||
|
||||
|
||||
__device__ void operator()(size_t i) {
|
||||
cbw.AtomicWriteSymbol(buffer_data_d, input_data_d[i], i);
|
||||
}
|
||||
@ -28,7 +28,7 @@ struct ReadSymbolFunction {
|
||||
|
||||
__device__ void operator()(size_t i) {
|
||||
output_data_d[i] = ci[i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TEST(CompressedIterator, TestGPU) {
|
||||
|
||||
@ -10,7 +10,7 @@
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
TEST(gpu_hist_util, TestDeviceSketch) {
|
||||
void TestDeviceSketch(const GPUSet& devices) {
|
||||
// create the data
|
||||
int nrows = 10001;
|
||||
std::vector<float> test_data(nrows);
|
||||
@ -28,7 +28,7 @@ TEST(gpu_hist_util, TestDeviceSketch) {
|
||||
tree::TrainParam p;
|
||||
p.max_bin = 20;
|
||||
p.gpu_id = 0;
|
||||
p.n_gpus = GPUSet::AllVisible().Size();
|
||||
p.n_gpus = devices.Size();
|
||||
// ensure that the exact quantiles are found
|
||||
p.gpu_batch_nrows = nrows * 10;
|
||||
|
||||
@ -54,5 +54,17 @@ TEST(gpu_hist_util, TestDeviceSketch) {
|
||||
delete dmat;
|
||||
}
|
||||
|
||||
TEST(gpu_hist_util, DeviceSketch) {
|
||||
TestDeviceSketch(GPUSet::Range(0, 1));
|
||||
}
|
||||
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
TEST(gpu_hist_util, MGPU_DeviceSketch) {
|
||||
auto devices = GPUSet::AllVisible();
|
||||
CHECK_GT(devices.Size(), 1);
|
||||
TestDeviceSketch(devices);
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
@ -178,18 +178,57 @@ TEST(HostDeviceVector, TestCopy) {
|
||||
SetCudaSetDeviceHandler(nullptr);
|
||||
}
|
||||
|
||||
// The test is not really useful if n_gpus < 2
|
||||
TEST(HostDeviceVector, Reshard) {
|
||||
std::vector<int> h_vec (2345);
|
||||
for (size_t i = 0; i < h_vec.size(); ++i) {
|
||||
h_vec[i] = i;
|
||||
}
|
||||
HostDeviceVector<int> vec (h_vec);
|
||||
auto devices = GPUSet::Range(0, 1);
|
||||
|
||||
vec.Reshard(devices);
|
||||
ASSERT_EQ(vec.DeviceSize(0), h_vec.size());
|
||||
ASSERT_EQ(vec.Size(), h_vec.size());
|
||||
auto span = vec.DeviceSpan(0); // sync to device
|
||||
|
||||
vec.Reshard(GPUSet::Empty()); // pull back to cpu, empty devices.
|
||||
ASSERT_EQ(vec.Size(), h_vec.size());
|
||||
ASSERT_TRUE(vec.Devices().IsEmpty());
|
||||
|
||||
auto h_vec_1 = vec.HostVector();
|
||||
ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
|
||||
}
|
||||
|
||||
TEST(HostDeviceVector, Span) {
|
||||
HostDeviceVector<float> vec {1.0f, 2.0f, 3.0f, 4.0f};
|
||||
vec.Reshard(GPUSet{0, 1});
|
||||
auto span = vec.DeviceSpan(0);
|
||||
ASSERT_EQ(vec.DeviceSize(0), span.size());
|
||||
ASSERT_EQ(vec.DevicePointer(0), span.data());
|
||||
auto const_span = vec.ConstDeviceSpan(0);
|
||||
ASSERT_EQ(vec.DeviceSize(0), span.size());
|
||||
ASSERT_EQ(vec.ConstDevicePointer(0), span.data());
|
||||
}
|
||||
|
||||
// Multi-GPUs' test
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
TEST(HostDeviceVector, MGPU_Reshard) {
|
||||
auto devices = GPUSet::AllVisible();
|
||||
if (devices.Size() < 2) {
|
||||
LOG(WARNING) << "Not testing in multi-gpu environment.";
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<int> h_vec (2345);
|
||||
for (size_t i = 0; i < h_vec.size(); ++i) {
|
||||
h_vec[i] = i;
|
||||
}
|
||||
HostDeviceVector<int> vec (h_vec);
|
||||
|
||||
// Data size for each device.
|
||||
std::vector<size_t> devices_size (devices.Size());
|
||||
|
||||
// From CPU to GPUs.
|
||||
// Assuming we have > 1 devices.
|
||||
vec.Reshard(devices);
|
||||
size_t total_size = 0;
|
||||
for (size_t i = 0; i < devices.Size(); ++i) {
|
||||
@ -198,42 +237,26 @@ TEST(HostDeviceVector, Reshard) {
|
||||
}
|
||||
ASSERT_EQ(total_size, h_vec.size());
|
||||
ASSERT_EQ(total_size, vec.Size());
|
||||
auto h_vec_1 = vec.HostVector();
|
||||
|
||||
ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
|
||||
vec.Reshard(GPUSet::Empty()); // clear out devices memory
|
||||
// Reshard from devices to devices with different distribution.
|
||||
EXPECT_ANY_THROW(
|
||||
vec.Reshard(GPUDistribution::Granular(devices, 12)));
|
||||
|
||||
// Shrink down the number of devices.
|
||||
vec.Reshard(GPUSet::Range(0, 1));
|
||||
// All data is drawn back to CPU
|
||||
vec.Reshard(GPUSet::Empty());
|
||||
ASSERT_TRUE(vec.Devices().IsEmpty());
|
||||
ASSERT_EQ(vec.Size(), h_vec.size());
|
||||
ASSERT_EQ(vec.DeviceSize(0), h_vec.size());
|
||||
h_vec_1 = vec.HostVector();
|
||||
ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
|
||||
vec.Reshard(GPUSet::Empty()); // clear out devices memory
|
||||
|
||||
// Grow the number of devices.
|
||||
vec.Reshard(devices);
|
||||
vec.Reshard(GPUDistribution::Granular(devices, 12));
|
||||
total_size = 0;
|
||||
for (size_t i = 0; i < devices.Size(); ++i) {
|
||||
total_size += vec.DeviceSize(i);
|
||||
ASSERT_EQ(devices_size[i], vec.DeviceSize(i));
|
||||
devices_size[i] = vec.DeviceSize(i);
|
||||
}
|
||||
ASSERT_EQ(total_size, h_vec.size());
|
||||
ASSERT_EQ(total_size, vec.Size());
|
||||
h_vec_1 = vec.HostVector();
|
||||
ASSERT_TRUE(std::equal(h_vec_1.cbegin(), h_vec_1.cend(), h_vec.cbegin()));
|
||||
}
|
||||
|
||||
TEST(HostDeviceVector, Span) {
|
||||
HostDeviceVector<float> vec {1.0f, 2.0f, 3.0f, 4.0f};
|
||||
vec.Reshard(GPUSet{0, 1});
|
||||
auto span = vec.DeviceSpan(0);
|
||||
ASSERT_EQ(vec.Size(), span.size());
|
||||
ASSERT_EQ(vec.DevicePointer(0), span.data());
|
||||
auto const_span = vec.ConstDeviceSpan(0);
|
||||
ASSERT_EQ(vec.Size(), span.size());
|
||||
ASSERT_EQ(vec.ConstDevicePointer(0), span.data());
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
|
||||
@ -7,6 +7,14 @@
|
||||
#include "../../include/xgboost/base.h"
|
||||
#include "../../../src/common/span.h"
|
||||
|
||||
template <typename Iter>
|
||||
XGBOOST_DEVICE void InitializeRange(Iter _begin, Iter _end) {
|
||||
float j = 0;
|
||||
for (Iter i = _begin; i != _end; ++i, ++j) {
|
||||
*i = j;
|
||||
}
|
||||
}
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
@ -20,14 +28,6 @@ namespace common {
|
||||
*(status) = -1; \
|
||||
}
|
||||
|
||||
template <typename Iter>
|
||||
XGBOOST_DEVICE void InitializeRange(Iter _begin, Iter _end) {
|
||||
float j = 0;
|
||||
for (Iter i = _begin; i != _end; ++i, ++j) {
|
||||
*i = j;
|
||||
}
|
||||
}
|
||||
|
||||
struct TestTestStatus {
|
||||
int * status_;
|
||||
|
||||
|
||||
61
tests/cpp/common/test_transform_range.cc
Normal file
61
tests/cpp/common/test_transform_range.cc
Normal file
@ -0,0 +1,61 @@
|
||||
#include <xgboost/base.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <vector>
|
||||
|
||||
#include "../../../src/common/host_device_vector.h"
|
||||
#include "../../../src/common/transform.h"
|
||||
#include "../../../src/common/span.h"
|
||||
#include "../helpers.h"
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
|
||||
#define TRANSFORM_GPU_RANGE GPUSet::Range(0, 1)
|
||||
#define TRANSFORM_GPU_DIST GPUDistribution::Block(GPUSet::Range(0, 1))
|
||||
|
||||
#else
|
||||
|
||||
#define TRANSFORM_GPU_RANGE GPUSet::Empty()
|
||||
#define TRANSFORM_GPU_DIST GPUDistribution::Block(GPUSet::Empty())
|
||||
|
||||
#endif
|
||||
|
||||
template <typename Iter>
|
||||
void InitializeRange(Iter _begin, Iter _end) {
|
||||
float j = 0;
|
||||
for (Iter i = _begin; i != _end; ++i, ++j) {
|
||||
*i = j;
|
||||
}
|
||||
}
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
template <typename T>
|
||||
struct TestTransformRange {
|
||||
void XGBOOST_DEVICE operator()(size_t _idx,
|
||||
Span<bst_float> _out, Span<const bst_float> _in) {
|
||||
_out[_idx] = _in[_idx];
|
||||
}
|
||||
};
|
||||
|
||||
TEST(Transform, DeclareUnifiedTest(Basic)) {
|
||||
const size_t size {256};
|
||||
std::vector<bst_float> h_in(size);
|
||||
std::vector<bst_float> h_out(size);
|
||||
InitializeRange(h_in.begin(), h_in.end());
|
||||
std::vector<bst_float> h_sol(size);
|
||||
InitializeRange(h_sol.begin(), h_sol.end());
|
||||
|
||||
const HostDeviceVector<bst_float> in_vec{h_in, TRANSFORM_GPU_DIST};
|
||||
HostDeviceVector<bst_float> out_vec{h_out, TRANSFORM_GPU_DIST};
|
||||
out_vec.Fill(0);
|
||||
|
||||
Transform<>::Init(TestTransformRange<bst_float>{}, Range{0, size}, TRANSFORM_GPU_RANGE)
|
||||
.Eval(&out_vec, &in_vec);
|
||||
std::vector<bst_float> res = out_vec.HostVector();
|
||||
|
||||
ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin()));
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace xgboost
|
||||
43
tests/cpp/common/test_transform_range.cu
Normal file
43
tests/cpp/common/test_transform_range.cu
Normal file
@ -0,0 +1,43 @@
|
||||
// This converts all tests from CPU to GPU.
|
||||
#include "test_transform_range.cc"
|
||||
|
||||
#if defined(XGBOOST_USE_NCCL)
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
|
||||
// Test here is multi gpu specific
|
||||
TEST(Transform, MGPU_Basic) {
|
||||
auto devices = GPUSet::AllVisible();
|
||||
CHECK_GT(devices.Size(), 1);
|
||||
const size_t size {256};
|
||||
std::vector<bst_float> h_in(size);
|
||||
std::vector<bst_float> h_out(size);
|
||||
InitializeRange(h_in.begin(), h_in.end());
|
||||
std::vector<bst_float> h_sol(size);
|
||||
InitializeRange(h_sol.begin(), h_sol.end());
|
||||
|
||||
const HostDeviceVector<bst_float> in_vec {h_in,
|
||||
GPUDistribution::Block(GPUSet::Empty())};
|
||||
HostDeviceVector<bst_float> out_vec {h_out,
|
||||
GPUDistribution::Block(GPUSet::Empty())};
|
||||
out_vec.Fill(0);
|
||||
|
||||
in_vec.Reshard(GPUDistribution::Granular(devices, 8));
|
||||
out_vec.Reshard(GPUDistribution::Block(devices));
|
||||
|
||||
// Granularity is different, resharding will throw.
|
||||
EXPECT_ANY_THROW(
|
||||
Transform<>::Init(TestTransformRange<bst_float>{}, Range{0, size}, devices)
|
||||
.Eval(&out_vec, &in_vec));
|
||||
|
||||
|
||||
Transform<>::Init(TestTransformRange<bst_float>{}, Range{0, size},
|
||||
devices, false).Eval(&out_vec, &in_vec);
|
||||
std::vector<bst_float> res = out_vec.HostVector();
|
||||
|
||||
ASSERT_TRUE(std::equal(h_sol.begin(), h_sol.end(), res.begin()));
|
||||
}
|
||||
|
||||
} // namespace xgboost
|
||||
} // namespace common
|
||||
#endif
|
||||
@ -125,3 +125,17 @@ std::shared_ptr<xgboost::DMatrix>* CreateDMatrix(int rows, int columns,
|
||||
&handle);
|
||||
return static_cast<std::shared_ptr<xgboost::DMatrix> *>(handle);
|
||||
}
|
||||
|
||||
namespace xgboost {
|
||||
bool IsNear(std::vector<xgboost::bst_float>::const_iterator _beg1,
|
||||
std::vector<xgboost::bst_float>::const_iterator _end1,
|
||||
std::vector<xgboost::bst_float>::const_iterator _beg2) {
|
||||
for (auto iter1 = _beg1, iter2 = _beg2; iter1 != _end1; ++iter1, ++iter2) {
|
||||
if (std::abs(*iter1 - *iter2) > xgboost::kRtEps){
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -15,6 +15,12 @@
|
||||
#include <xgboost/objective.h>
|
||||
#include <xgboost/metric.h>
|
||||
|
||||
#if defined(__CUDACC__)
|
||||
#define DeclareUnifiedTest(name) GPU ## name
|
||||
#else
|
||||
#define DeclareUnifiedTest(name) name
|
||||
#endif
|
||||
|
||||
std::string TempFileName();
|
||||
|
||||
bool FileExists(const std::string name);
|
||||
@ -46,6 +52,12 @@ xgboost::bst_float GetMetricEval(
|
||||
std::vector<xgboost::bst_float> labels,
|
||||
std::vector<xgboost::bst_float> weights = std::vector<xgboost::bst_float> ());
|
||||
|
||||
namespace xgboost {
|
||||
bool IsNear(std::vector<xgboost::bst_float>::const_iterator _beg1,
|
||||
std::vector<xgboost::bst_float>::const_iterator _end1,
|
||||
std::vector<xgboost::bst_float>::const_iterator _beg2);
|
||||
}
|
||||
|
||||
/**
|
||||
* \fn std::shared_ptr<xgboost::DMatrix> CreateDMatrix(int rows, int columns, float sparsity, int seed);
|
||||
*
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
|
||||
#include "../helpers.h"
|
||||
|
||||
TEST(Objective, HingeObj) {
|
||||
TEST(Objective, DeclareUnifiedTest(HingeObj)) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("binary:hinge");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
obj->Configure(args);
|
||||
@ -15,6 +15,12 @@ TEST(Objective, HingeObj) {
|
||||
{ 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f},
|
||||
{ 0.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 0.0f},
|
||||
{ eps, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, eps });
|
||||
CheckObjFunction(obj,
|
||||
{-1.0f, -0.5f, 0.5f, 1.0f, -1.0f, -0.5f, 0.5f, 1.0f},
|
||||
{ 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f, 1.0f},
|
||||
{}, // Empty weight.
|
||||
{ 0.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 0.0f},
|
||||
{ eps, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, eps });
|
||||
|
||||
ASSERT_NO_THROW(obj->DefaultEvalMetric());
|
||||
|
||||
|
||||
1
tests/cpp/objective/test_hinge.cu
Normal file
1
tests/cpp/objective/test_hinge.cu
Normal file
@ -0,0 +1 @@
|
||||
#include "test_hinge.cc"
|
||||
60
tests/cpp/objective/test_multiclass_obj.cc
Normal file
60
tests/cpp/objective/test_multiclass_obj.cc
Normal file
@ -0,0 +1,60 @@
|
||||
/*!
|
||||
* Copyright 2018 XGBoost contributors
|
||||
*/
|
||||
#include <xgboost/objective.h>
|
||||
|
||||
#include "../helpers.h"
|
||||
|
||||
TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassObjGPair)) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("multi:softmax");
|
||||
std::vector<std::pair<std::string, std::string>> args {{"num_class", "3"}};
|
||||
obj->Configure(args);
|
||||
CheckObjFunction(obj,
|
||||
{1, 0, 2, 2, 0, 1}, // preds
|
||||
{1.0, 0.0}, // labels
|
||||
{1.0, 1.0}, // weights
|
||||
{0.24f, -0.91f, 0.66f, -0.33f, 0.09f, 0.24f}, // grad
|
||||
{0.36, 0.16, 0.44, 0.45, 0.16, 0.37}); // hess
|
||||
|
||||
ASSERT_NO_THROW(obj->DefaultEvalMetric());
|
||||
|
||||
delete obj;
|
||||
}
|
||||
|
||||
TEST(Objective, DeclareUnifiedTest(SoftmaxMultiClassBasic)) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("multi:softmax");
|
||||
std::vector<std::pair<std::string, std::string>> args
|
||||
{std::pair<std::string, std::string>("num_class", "3")};
|
||||
obj->Configure(args);
|
||||
|
||||
xgboost::HostDeviceVector<xgboost::bst_float> io_preds = {2.0f, 0.0f, 1.0f,
|
||||
1.0f, 0.0f, 2.0f};
|
||||
std::vector<xgboost::bst_float> out_preds = {0.0f, 2.0f};
|
||||
obj->PredTransform(&io_preds);
|
||||
|
||||
auto& preds = io_preds.HostVector();
|
||||
|
||||
for (int i = 0; i < static_cast<int>(io_preds.Size()); ++i) {
|
||||
EXPECT_NEAR(preds[i], out_preds[i], 0.01f);
|
||||
}
|
||||
|
||||
delete obj;
|
||||
}
|
||||
|
||||
TEST(Objective, DeclareUnifiedTest(SoftprobMultiClassBasic)) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("multi:softprob");
|
||||
std::vector<std::pair<std::string, std::string>> args
|
||||
{std::pair<std::string, std::string>("num_class", "3")};
|
||||
obj->Configure(args);
|
||||
|
||||
xgboost::HostDeviceVector<xgboost::bst_float> io_preds = {2.0f, 0.0f, 1.0f};
|
||||
std::vector<xgboost::bst_float> out_preds = {0.66524096f, 0.09003057f, 0.24472847f};
|
||||
|
||||
obj->PredTransform(&io_preds);
|
||||
auto& preds = io_preds.HostVector();
|
||||
|
||||
for (int i = 0; i < static_cast<int>(io_preds.Size()); ++i) {
|
||||
EXPECT_NEAR(preds[i], out_preds[i], 0.01f);
|
||||
}
|
||||
delete obj;
|
||||
}
|
||||
1
tests/cpp/objective/test_multiclass_obj_gpu.cu
Normal file
1
tests/cpp/objective/test_multiclass_obj_gpu.cu
Normal file
@ -0,0 +1 @@
|
||||
#include "test_multiclass_obj.cc"
|
||||
@ -1,9 +1,11 @@
|
||||
// Copyright by Contributors
|
||||
/*!
|
||||
* Copyright 2017-2018 XGBoost contributors
|
||||
*/
|
||||
#include <xgboost/objective.h>
|
||||
|
||||
#include "../helpers.h"
|
||||
|
||||
TEST(Objective, LinearRegressionGPair) {
|
||||
TEST(Objective, DeclareUnifiedTest(LinearRegressionGPair)) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:linear");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
obj->Configure(args);
|
||||
@ -13,27 +15,32 @@ TEST(Objective, LinearRegressionGPair) {
|
||||
{1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{0, 0.1f, 0.9f, 1.0f, -1.0f, -0.9f, -0.1f, 0},
|
||||
{1, 1, 1, 1, 1, 1, 1, 1});
|
||||
|
||||
CheckObjFunction(obj,
|
||||
{0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
|
||||
{0, 0, 0, 0, 1, 1, 1, 1},
|
||||
{}, // empty weight
|
||||
{0, 0.1f, 0.9f, 1.0f, -1.0f, -0.9f, -0.1f, 0},
|
||||
{1, 1, 1, 1, 1, 1, 1, 1});
|
||||
ASSERT_NO_THROW(obj->DefaultEvalMetric());
|
||||
|
||||
delete obj;
|
||||
}
|
||||
|
||||
TEST(Objective, LogisticRegressionGPair) {
|
||||
TEST(Objective, DeclareUnifiedTest(LogisticRegressionGPair)) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:logistic");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
obj->Configure(args);
|
||||
CheckObjFunction(obj,
|
||||
{ 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
|
||||
{ 0, 0, 0, 0, 1, 1, 1, 1},
|
||||
{ 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{ 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f},
|
||||
{0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f});
|
||||
{ 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1}, // preds
|
||||
{ 0, 0, 0, 0, 1, 1, 1, 1}, // labels
|
||||
{ 1, 1, 1, 1, 1, 1, 1, 1}, // weights
|
||||
{ 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f}, // out_grad
|
||||
{0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f}); // out_hess
|
||||
|
||||
delete obj;
|
||||
}
|
||||
|
||||
TEST(Objective, LogisticRegressionBasic) {
|
||||
TEST(Objective, DeclareUnifiedTest(LogisticRegressionBasic)) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:logistic");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
obj->Configure(args);
|
||||
@ -61,7 +68,7 @@ TEST(Objective, LogisticRegressionBasic) {
|
||||
delete obj;
|
||||
}
|
||||
|
||||
TEST(Objective, LogisticRawGPair) {
|
||||
TEST(Objective, DeclareUnifiedTest(LogisticRawGPair)) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("binary:logitraw");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
obj->Configure(args);
|
||||
@ -75,7 +82,7 @@ TEST(Objective, LogisticRawGPair) {
|
||||
delete obj;
|
||||
}
|
||||
|
||||
TEST(Objective, PoissonRegressionGPair) {
|
||||
TEST(Objective, DeclareUnifiedTest(PoissonRegressionGPair)) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("count:poisson");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
args.push_back(std::make_pair("max_delta_step", "0.1f"));
|
||||
@ -86,11 +93,16 @@ TEST(Objective, PoissonRegressionGPair) {
|
||||
{ 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{ 1, 1.10f, 2.45f, 2.71f, 0, 0.10f, 1.45f, 1.71f},
|
||||
{1.10f, 1.22f, 2.71f, 3.00f, 1.10f, 1.22f, 2.71f, 3.00f});
|
||||
|
||||
CheckObjFunction(obj,
|
||||
{ 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
|
||||
{ 0, 0, 0, 0, 1, 1, 1, 1},
|
||||
{}, // Empty weight
|
||||
{ 1, 1.10f, 2.45f, 2.71f, 0, 0.10f, 1.45f, 1.71f},
|
||||
{1.10f, 1.22f, 2.71f, 3.00f, 1.10f, 1.22f, 2.71f, 3.00f});
|
||||
delete obj;
|
||||
}
|
||||
|
||||
TEST(Objective, PoissonRegressionBasic) {
|
||||
TEST(Objective, DeclareUnifiedTest(PoissonRegressionBasic)) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("count:poisson");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
obj->Configure(args);
|
||||
@ -116,7 +128,7 @@ TEST(Objective, PoissonRegressionBasic) {
|
||||
delete obj;
|
||||
}
|
||||
|
||||
TEST(Objective, GammaRegressionGPair) {
|
||||
TEST(Objective, DeclareUnifiedTest(GammaRegressionGPair)) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:gamma");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
obj->Configure(args);
|
||||
@ -126,11 +138,16 @@ TEST(Objective, GammaRegressionGPair) {
|
||||
{1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{1, 1, 1, 1, 0, 0.09f, 0.59f, 0.63f},
|
||||
{0, 0, 0, 0, 1, 0.90f, 0.40f, 0.36f});
|
||||
|
||||
CheckObjFunction(obj,
|
||||
{0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
|
||||
{0, 0, 0, 0, 1, 1, 1, 1},
|
||||
{}, // Empty weight
|
||||
{1, 1, 1, 1, 0, 0.09f, 0.59f, 0.63f},
|
||||
{0, 0, 0, 0, 1, 0.90f, 0.40f, 0.36f});
|
||||
delete obj;
|
||||
}
|
||||
|
||||
TEST(Objective, GammaRegressionBasic) {
|
||||
TEST(Objective, DeclareUnifiedTest(GammaRegressionBasic)) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:gamma");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
obj->Configure(args);
|
||||
@ -156,7 +173,7 @@ TEST(Objective, GammaRegressionBasic) {
|
||||
delete obj;
|
||||
}
|
||||
|
||||
TEST(Objective, TweedieRegressionGPair) {
|
||||
TEST(Objective, DeclareUnifiedTest(TweedieRegressionGPair)) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:tweedie");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
args.push_back(std::make_pair("tweedie_variance_power", "1.1f"));
|
||||
@ -167,11 +184,17 @@ TEST(Objective, TweedieRegressionGPair) {
|
||||
{ 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{ 1, 1.09f, 2.24f, 2.45f, 0, 0.10f, 1.33f, 1.55f},
|
||||
{0.89f, 0.98f, 2.02f, 2.21f, 1, 1.08f, 2.11f, 2.30f});
|
||||
CheckObjFunction(obj,
|
||||
{ 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
|
||||
{ 0, 0, 0, 0, 1, 1, 1, 1},
|
||||
{}, // Empty weight.
|
||||
{ 1, 1.09f, 2.24f, 2.45f, 0, 0.10f, 1.33f, 1.55f},
|
||||
{0.89f, 0.98f, 2.02f, 2.21f, 1, 1.08f, 2.11f, 2.30f});
|
||||
|
||||
delete obj;
|
||||
}
|
||||
|
||||
TEST(Objective, TweedieRegressionBasic) {
|
||||
TEST(Objective, DeclareUnifiedTest(TweedieRegressionBasic)) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("reg:tweedie");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
obj->Configure(args);
|
||||
@ -197,6 +220,9 @@ TEST(Objective, TweedieRegressionBasic) {
|
||||
delete obj;
|
||||
}
|
||||
|
||||
|
||||
// CoxRegression not implemented in GPU code, no need for testing.
|
||||
#if !defined(__CUDACC__)
|
||||
TEST(Objective, CoxRegressionGPair) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("survival:cox");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
@ -210,3 +236,4 @@ TEST(Objective, CoxRegressionGPair) {
|
||||
|
||||
delete obj;
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -1,78 +1,6 @@
|
||||
/*!
|
||||
* Copyright 2017 XGBoost contributors
|
||||
* Copyright 2018 XGBoost contributors
|
||||
*/
|
||||
#include <xgboost/objective.h>
|
||||
// Dummy file to keep the CUDA tests.
|
||||
|
||||
#include "../helpers.h"
|
||||
|
||||
TEST(Objective, GPULinearRegressionGPair) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:reg:linear");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
obj->Configure(args);
|
||||
CheckObjFunction(obj,
|
||||
{0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
|
||||
{0, 0, 0, 0, 1, 1, 1, 1},
|
||||
{1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{0, 0.1f, 0.9f, 1.0f, -1.0f, -0.9f, -0.1f, 0},
|
||||
{1, 1, 1, 1, 1, 1, 1, 1});
|
||||
|
||||
ASSERT_NO_THROW(obj->DefaultEvalMetric());
|
||||
|
||||
delete obj;
|
||||
}
|
||||
|
||||
TEST(Objective, GPULogisticRegressionGPair) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:reg:logistic");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
obj->Configure(args);
|
||||
CheckObjFunction(obj,
|
||||
{ 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
|
||||
{ 0, 0, 0, 0, 1, 1, 1, 1},
|
||||
{ 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{ 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f},
|
||||
{0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f});
|
||||
|
||||
delete obj;
|
||||
}
|
||||
|
||||
TEST(Objective, GPULogisticRegressionBasic) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:reg:logistic");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
obj->Configure(args);
|
||||
|
||||
// test label validation
|
||||
EXPECT_ANY_THROW(CheckObjFunction(obj, {0}, {10}, {1}, {0}, {0}))
|
||||
<< "Expected error when label not in range [0,1f] for LogisticRegression";
|
||||
|
||||
// test ProbToMargin
|
||||
EXPECT_NEAR(obj->ProbToMargin(0.1f), -2.197f, 0.01f);
|
||||
EXPECT_NEAR(obj->ProbToMargin(0.5f), 0, 0.01f);
|
||||
EXPECT_NEAR(obj->ProbToMargin(0.9f), 2.197f, 0.01f);
|
||||
EXPECT_ANY_THROW(obj->ProbToMargin(10))
|
||||
<< "Expected error when base_score not in range [0,1f] for LogisticRegression";
|
||||
|
||||
// test PredTransform
|
||||
xgboost::HostDeviceVector<xgboost::bst_float> io_preds = {0, 0.1f, 0.5f, 0.9f, 1};
|
||||
std::vector<xgboost::bst_float> out_preds = {0.5f, 0.524f, 0.622f, 0.710f, 0.731f};
|
||||
obj->PredTransform(&io_preds);
|
||||
auto& preds = io_preds.HostVector();
|
||||
for (int i = 0; i < static_cast<int>(io_preds.Size()); ++i) {
|
||||
EXPECT_NEAR(preds[i], out_preds[i], 0.01f);
|
||||
}
|
||||
|
||||
delete obj;
|
||||
}
|
||||
|
||||
TEST(Objective, GPULogisticRawGPair) {
|
||||
xgboost::ObjFunction * obj = xgboost::ObjFunction::Create("gpu:binary:logitraw");
|
||||
std::vector<std::pair<std::string, std::string> > args;
|
||||
obj->Configure(args);
|
||||
CheckObjFunction(obj,
|
||||
{ 0, 0.1f, 0.9f, 1, 0, 0.1f, 0.9f, 1},
|
||||
{ 0, 0, 0, 0, 1, 1, 1, 1},
|
||||
{ 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
{ 0.5f, 0.52f, 0.71f, 0.73f, -0.5f, -0.47f, -0.28f, -0.26f},
|
||||
{0.25f, 0.24f, 0.20f, 0.19f, 0.25f, 0.24f, 0.20f, 0.19f});
|
||||
|
||||
delete obj;
|
||||
}
|
||||
#include "test_regression_obj.cc"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user