AVX gradients (#2878)
* AVX gradients
* Add google test for AVX
* Create fallback implementation, remove fma instruction
* Improved accuracy of AVX exp function
This commit is contained in: parent 3dcf966bc3 · commit 24f527a1c0
CMakeLists.txt
@@ -9,6 +9,7 @@ msvc_use_static_runtime()

# Options
option(USE_CUDA "Build with GPU acceleration")
+option(USE_AVX "Build with AVX instructions. May not produce identical results due to approximate math." OFF)
option(USE_NCCL "Build using NCCL for multi-GPU. Also requires USE_CUDA")
option(JVM_BINDINGS "Build JVM bindings" OFF)
option(GOOGLE_TEST "Build google tests" OFF)
@@ -39,6 +40,17 @@ else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -funroll-loops")
endif()

+# AVX
+if(USE_AVX)
+  if(MSVC)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX")
+  else()
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx")
+  endif()
+  add_definitions(-DXGBOOST_USE_AVX)
+endif()
+
# compiled code customizations for R package
if(R_LIB)
add_definitions(
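Enabling the new option is a single configure-time flag; for example (the same invocation as the Travis script at the bottom of this diff):

    mkdir build && cd build
    cmake .. -DGOOGLE_TEST=ON -DUSE_AVX=ON
    make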
@@ -212,12 +224,12 @@ endif()

# Test
if(GOOGLE_TEST)
-  enable_testing()
  find_package(GTest REQUIRED)
+  enable_testing()

  file(GLOB_RECURSE TEST_SOURCES "tests/cpp/*.cc")
  auto_source_group("${TEST_SOURCES}")
-  include_directories(${GTEST_INCLUDE_DIRS})
+  include_directories(${GTEST_INCLUDE_DIR})

  if(USE_CUDA)
    file(GLOB_RECURSE CUDA_TEST_SOURCES "tests/cpp/*.cu")
@@ -228,7 +240,7 @@ if(GOOGLE_TEST)

  add_executable(testxgboost ${TEST_SOURCES} ${CUDA_TEST_OBJS} $<TARGET_OBJECTS:objxgboost>)
  set_output_directory(testxgboost ${PROJECT_SOURCE_DIR})
-  target_link_libraries(testxgboost ${GTEST_BOTH_LIBRARIES} ${LINK_LIBRARIES})
+  target_link_libraries(testxgboost ${GTEST_LIBRARIES} ${LINK_LIBRARIES})

  add_test(TestXGBoost testxgboost)
endif()
cmake/modules/FindGTest.cmake (new file, 79 lines)
@@ -0,0 +1,79 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Tries to find GTest headers and libraries.
#
# Usage of this module as follows:
#
#   find_package(GTest)
#
# Variables used by this module, they can change the default behaviour and need
# to be set before calling find_package:
#
#   GTest_HOME - When set, this path is inspected instead of standard library
#                locations as the root of the GTest installation.
#                The environment variable GTEST_HOME overrides this variable.
#
# This module defines
#   GTEST_INCLUDE_DIR, directory containing headers
#   GTEST_LIBS, directory containing gtest libraries
#   GTEST_STATIC_LIB, path to libgtest.a
#   GTEST_SHARED_LIB, path to libgtest's shared library
#   GTEST_FOUND, whether gtest has been found

find_path(GTEST_INCLUDE_DIR NAMES gtest/gtest.h gtest.h PATHS ${CMAKE_SOURCE_DIR}/gtest/include NO_DEFAULT_PATH)
find_library(GTEST_LIBRARIES NAMES gtest PATHS ${CMAKE_SOURCE_DIR}/gtest/lib NO_DEFAULT_PATH)

if (GTEST_INCLUDE_DIR)
  message(STATUS "Found the GTest includes: ${GTEST_INCLUDE_DIR}")
endif ()

if (GTEST_INCLUDE_DIR AND GTEST_LIBRARIES)
  set(GTEST_FOUND TRUE)
  get_filename_component(GTEST_LIBS ${GTEST_LIBRARIES} PATH)
  set(GTEST_LIB_NAME gtest)
  set(GTEST_STATIC_LIB ${GTEST_LIBS}/${CMAKE_STATIC_LIBRARY_PREFIX}${GTEST_LIB_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX})
  set(GTEST_MAIN_STATIC_LIB ${GTEST_LIBS}/${CMAKE_STATIC_LIBRARY_PREFIX}${GTEST_LIB_NAME}_main${CMAKE_STATIC_LIBRARY_SUFFIX})
  set(GTEST_SHARED_LIB ${GTEST_LIBS}/${CMAKE_SHARED_LIBRARY_PREFIX}${GTEST_LIB_NAME}${CMAKE_SHARED_LIBRARY_SUFFIX})
else ()
  set(GTEST_FOUND FALSE)
endif ()

if (GTEST_FOUND)
  if (NOT GTest_FIND_QUIETLY)
    message(STATUS "Found the GTest library: ${GTEST_LIBRARIES}")
  endif ()
else ()
  if (NOT GTest_FIND_QUIETLY)
    set(GTEST_ERR_MSG "Could not find the GTest library. Looked in")
    if (_gtest_roots)
      set(GTEST_ERR_MSG "${GTEST_ERR_MSG} ${_gtest_roots}.")
    else ()
      set(GTEST_ERR_MSG "${GTEST_ERR_MSG} system search paths.")
    endif ()
    if (GTest_FIND_REQUIRED)
      message(FATAL_ERROR "${GTEST_ERR_MSG}")
    else (GTest_FIND_REQUIRED)
      message(STATUS "${GTEST_ERR_MSG}")
    endif (GTest_FIND_REQUIRED)
  endif ()
endif ()

mark_as_advanced(
  GTEST_INCLUDE_DIR
  GTEST_LIBS
  GTEST_LIBRARIES
  GTEST_STATIC_LIB
  GTEST_SHARED_LIB
)
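The module only searches ${CMAKE_SOURCE_DIR}/gtest (headers under gtest/include, libraries under gtest/lib), so a checkout needs gtest staged there first. A sketch of preparing that layout, mirroring the Travis steps at the bottom of this diff:

    wget https://github.com/google/googletest/archive/release-1.7.0.zip
    unzip release-1.7.0.zip
    mv googletest-release-1.7.0 gtest && cd gtest
    cmake . && make
    mkdir lib && mv libgtest.a lib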
src/common/avx_helpers.h (new file, 287 lines)
@@ -0,0 +1,287 @@
/*!
 * Copyright 2017 by Contributors
 * \author Rory Mitchell
 */
#pragma once
#include <algorithm>
#include <cmath>   // std::exp (fallback Sigmoid below)
#include <cstdio>  // printf (Float8::Print)
#include "xgboost/base.h"

#ifdef XGBOOST_USE_AVX
#include <immintrin.h>  // _mm256_* intrinsics (added for self-containedness; original relied on transitive includes)
namespace avx {
/**
 * \struct Float8
 *
 * \brief Helper class for processing a vector of eight floats using AVX
 * instructions. Implements basic math operators.
 */

struct Float8 {
  __m256 x;
  explicit Float8(const __m256& x) : x(x) {}
  explicit Float8(const float& val) : x(_mm256_broadcast_ss(&val)) {}
  explicit Float8(const float* vec) : x(_mm256_loadu_ps(vec)) {}
  Float8() : x() {}
  Float8& operator+=(const Float8& rhs) {
    x = _mm256_add_ps(x, rhs.x);
    return *this;
  }
  Float8& operator-=(const Float8& rhs) {
    x = _mm256_sub_ps(x, rhs.x);
    return *this;
  }
  Float8& operator*=(const Float8& rhs) {
    x = _mm256_mul_ps(x, rhs.x);
    return *this;
  }
  Float8& operator/=(const Float8& rhs) {
    x = _mm256_div_ps(x, rhs.x);
    return *this;
  }
  void Print() {
    float* f = reinterpret_cast<float*>(&x);
    printf("%f %f %f %f %f %f %f %f\n", f[0], f[1], f[2], f[3], f[4], f[5],
           f[6], f[7]);
  }
};

inline Float8 operator+(Float8 lhs, const Float8& rhs) {
  lhs += rhs;
  return lhs;
}
inline Float8 operator-(Float8 lhs, const Float8& rhs) {
  lhs -= rhs;
  return lhs;
}
inline Float8 operator*(Float8 lhs, const Float8& rhs) {
  lhs *= rhs;
  return lhs;
}
inline Float8 operator/(Float8 lhs, const Float8& rhs) {
  lhs /= rhs;
  return lhs;
}

inline Float8 round(const Float8& x) {
  return Float8(_mm256_round_ps(x.x, _MM_FROUND_TO_NEAREST_INT));
}
}  // namespace avx

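A minimal usage sketch (illustration only, not part of the commit; assumes a build with -DUSE_AVX=ON): Float8 wraps one __m256 register, so each overloaded operator is a single eight-wide instruction.

    #include "src/common/avx_helpers.h"
    int main() {
      float a[8] = {1, 2, 3, 4, 5, 6, 7, 8};
      avx::Float8 v(a);                           // load eight floats
      avx::Float8 w = v * v + avx::Float8(1.0f);  // x*x + 1 across all lanes
      w.Print();  // prints 2.000000 5.000000 ... 65.000000
      return 0;
    }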
// Overload std::max/min
namespace std {
inline avx::Float8 max(const avx::Float8& a, const avx::Float8& b) {
  return avx::Float8(_mm256_max_ps(a.x, b.x));
}
inline avx::Float8 min(const avx::Float8& a, const avx::Float8& b) {
  return avx::Float8(_mm256_min_ps(a.x, b.x));
}
}  // namespace std

namespace avx {

// https://codingforspeed.com/using-faster-exponential-approximation/
// Approximates exp(x) as (1 + x/4096)^4096; the 4096th power is taken
// with twelve successive squarings, since 4096 = 2^12.
inline Float8 Exp4096(Float8 x) {
  x *= Float8(1.0f / 4096.0f);
  x += Float8(1.0f);
  x *= x;
  x *= x;
  x *= x;
  x *= x;
  x *= x;
  x *= x;
  x *= x;
  x *= x;
  x *= x;
  x *= x;
  x *= x;
  x *= x;
  return x;
}
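A scalar sketch of the same approximation (illustration only, not in the commit):

    #include <cmath>
    float Exp4096Scalar(float x) {
      x = 1.0f + x / 4096.0f;
      for (int i = 0; i < 12; ++i) x *= x;  // raise to the 4096th power
      return x;
    }
    // Exp4096Scalar(1.0f) ~= 2.71795, vs. std::exp(1.0f) ~= 2.71828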
inline Float8 pow2n(Float8 const& n) {
  const float pow2_23 = 8388608.0;  // 2^23
  const float bias = 127.0;         // bias in exponent
  Float8 a =
      n + Float8(bias + pow2_23);  // put n + bias in least significant bits
  __m256i b = _mm256_castps_si256(a.x);

  // Do bit shift in SSE so we don't have to use AVX2 instructions
  __m128i c1 = _mm256_castsi256_si128(b);
  b = _mm256_permute2f128_si256(b, b, 1);
  __m128i c2 = _mm256_castsi256_si128(b);
  c1 = _mm_slli_epi32(c1, 23);
  c2 = _mm_slli_epi32(c2, 23);

  __m256i c = _mm256_insertf128_si256(_mm256_castsi128_si256(c1), (c2), 0x1);
  return Float8(_mm256_castsi256_ps(c));
}
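The trick here: adding 2^23 forces n + 127 into the low mantissa bits, and the left shift by 23 lands that value in the IEEE-754 exponent field, which yields 2^n exactly for integral n in the normal range. A scalar sketch of the same bit manipulation (illustration only):

    #include <cstdint>
    #include <cstring>
    float Pow2nScalar(float n) {
      uint32_t bits = static_cast<uint32_t>(n + 127.0f) << 23;  // exponent field
      float out;
      std::memcpy(&out, &bits, sizeof(out));  // reinterpret bits as float
      return out;
    }
    // Pow2nScalar(3.0f) == 8.0f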
inline Float8 polynomial_5(Float8 const& x, const float c0, const float c1,
                           const float c2, const float c3, const float c4,
                           const float c5) {
  // calculates polynomial c5*x^5 + c4*x^4 + c3*x^3 + c2*x^2 + c1*x + c0
  // using Estrin-style grouping to shorten the dependency chain
  Float8 x2 = x * x;
  Float8 x4 = x2 * x2;
  return (Float8(c2) + Float8(c3) * x) * x2 +
         ((Float8(c4) + Float8(c5) * x) * x4 + (Float8(c0) + Float8(c1) * x));
}

// AVX exp function based on Agner Fog's vector library
// https://github.com/darealshinji/vectorclass/blob/master/vectormath_exp.h
// Modified so it doesn't require AVX2 instructions
// Clamps input values to the range [-87.3f, +87.3f]
inline Float8 ExpAgner(Float8 x) {
  // Clamp input values
  float max_x = 87.3f;
  x = std::min(x, Float8(max_x));
  x = std::max(x, Float8(-max_x));

  // 1/log(2)
  const float log2e = 1.44269504088896340736f;

  // Taylor coefficients
  const float P0expf = 1.f / 2.f;
  const float P1expf = 1.f / 6.f;
  const float P2expf = 1.f / 24.f;
  const float P3expf = 1.f / 120.f;
  const float P4expf = 1.f / 720.f;
  const float P5expf = 1.f / 5040.f;

  // log(2) split into a high and a low part for extra precision
  const float ln2f_hi = 0.693359375f;
  const float ln2f_lo = -2.12194440e-4f;

  // Range reduction: x = r*log(2) + rem, so exp(x) = 2^r * exp(rem)
  Float8 r = round(x * Float8(log2e));
  x -= r * Float8(ln2f_hi);
  x -= r * Float8(ln2f_lo);

  Float8 x2 = x * x;
  Float8 z = polynomial_5(x, P0expf, P1expf, P2expf, P3expf, P4expf, P5expf);
  z *= x2;
  z += x;

  // multiply by power of 2
  Float8 n2 = pow2n(r);

  z = (z + Float8(1.0f)) * n2;
  return z;
}
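A quick way to sanity-check the approximation (a sketch under the AVX build, not part of the commit):

    #include <cmath>
    #include <cstdio>
    void CompareExp() {
      float in[8] = {-10.f, -1.f, -0.5f, 0.f, 0.5f, 1.f, 10.f, 80.f};
      float out[8];
      _mm256_storeu_ps(out, avx::ExpAgner(avx::Float8(in)).x);
      for (int i = 0; i < 8; ++i)
        std::printf("%g (approx) vs %g (std::exp)\n", out[i], std::exp(in[i]));
    }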
inline Float8 Sigmoid(Float8 x) {
  Float8 exp = ExpAgner(x * Float8(-1.0f));
  x = Float8(1.0f) + exp;
  // _mm256_rcp_ps is an approximate (~12-bit) reciprocal, traded for speed
  return Float8(_mm256_rcp_ps(x.x));
}

// Store 8 gradient pairs given vectors containing gradient and Hessian
inline void StoreGpair(xgboost::bst_gpair* dst, const Float8& grad,
                       const Float8& hess) {
  float* ptr = reinterpret_cast<float*>(dst);
  // Interleave grad/hess lanes, then fix up the 128-bit halves so the
  // pairs land in memory order: (g0,h0), (g1,h1), ..., (g7,h7)
  __m256 gpair_low = _mm256_unpacklo_ps(grad.x, hess.x);
  __m256 gpair_high = _mm256_unpackhi_ps(grad.x, hess.x);
  _mm256_storeu_ps(ptr, _mm256_permute2f128_ps(gpair_low, gpair_high, 0x20));
  _mm256_storeu_ps(ptr + 8,
                   _mm256_permute2f128_ps(gpair_low, gpair_high, 0x31));
}
}  // namespace avx
#else
namespace avx {
/**
 * \struct Float8
 *
 * \brief Fallback implementation not using AVX.
 */

struct Float8 {
  float x[8];
  explicit Float8(const float& val) {
    for (int i = 0; i < 8; i++) {
      x[i] = val;
    }
  }
  explicit Float8(const float* vec) {
    for (int i = 0; i < 8; i++) {
      x[i] = vec[i];
    }
  }
  Float8() {}
  Float8& operator+=(const Float8& rhs) {
    for (int i = 0; i < 8; i++) {
      x[i] += rhs.x[i];
    }
    return *this;
  }
  Float8& operator-=(const Float8& rhs) {
    for (int i = 0; i < 8; i++) {
      x[i] -= rhs.x[i];
    }
    return *this;
  }
  Float8& operator*=(const Float8& rhs) {
    for (int i = 0; i < 8; i++) {
      x[i] *= rhs.x[i];
    }
    return *this;
  }
  Float8& operator/=(const Float8& rhs) {
    for (int i = 0; i < 8; i++) {
      x[i] /= rhs.x[i];
    }
    return *this;
  }
  void Print() {
    float* f = reinterpret_cast<float*>(&x);
    printf("%f %f %f %f %f %f %f %f\n", f[0], f[1], f[2], f[3], f[4], f[5],
           f[6], f[7]);
  }
};

inline Float8 operator+(Float8 lhs, const Float8& rhs) {
  lhs += rhs;
  return lhs;
}
inline Float8 operator-(Float8 lhs, const Float8& rhs) {
  lhs -= rhs;
  return lhs;
}
inline Float8 operator*(Float8 lhs, const Float8& rhs) {
  lhs *= rhs;
  return lhs;
}
inline Float8 operator/(Float8 lhs, const Float8& rhs) {
  lhs /= rhs;
  return lhs;
}

// Store 8 gradient pairs given vectors containing gradient and Hessian
inline void StoreGpair(xgboost::bst_gpair* dst, const Float8& grad,
                       const Float8& hess) {
  for (int i = 0; i < 8; i++) {
    dst[i] = xgboost::bst_gpair(grad.x[i], hess.x[i]);
  }
}

inline Float8 Sigmoid(Float8 x) {
  Float8 sig;
  for (int i = 0; i < 8; i++) {
    sig.x[i] = 1.0f / (1.0f + std::exp(-x.x[i]));
  }
  return sig;
}
}  // namespace avx

namespace std {
inline avx::Float8 max(const avx::Float8& a, const avx::Float8& b) {
  avx::Float8 max;
  for (int i = 0; i < 8; i++) {
    max.x[i] = std::max(a.x[i], b.x[i]);
  }
  return max;
}
inline avx::Float8 min(const avx::Float8& a, const avx::Float8& b) {
  avx::Float8 min;
  for (int i = 0; i < 8; i++) {
    min.x[i] = std::min(a.x[i], b.x[i]);
  }
  return min;
}
}  // namespace std
#endif
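Both branches expose the same Float8 interface, so code written against it compiles with or without -DUSE_AVX. A small demo (illustration only, not part of the commit):

    void SigmoidDemo() {
      float in[8] = {-2.f, -1.f, -0.5f, 0.f, 0.5f, 1.f, 2.f, 4.f};
      avx::Float8 s = avx::Sigmoid(avx::Float8(in));
      s.Print();  // eight sigmoid values from whichever implementation was built
    }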
src/common/math.h
@@ -11,6 +11,7 @@
#include <vector>
#include <cmath>
#include <algorithm>
+#include "avx_helpers.h"

namespace xgboost {
namespace common {
@@ -23,6 +24,10 @@ inline float Sigmoid(float x) {
  return 1.0f / (1.0f + std::exp(-x));
}

+inline avx::Float8 Sigmoid(avx::Float8 x) {
+  return avx::Sigmoid(x);
+}
+
/*!
 * \brief do inplace softmax transformation on p_rec
 * \param p_rec the input/output vector of the values.
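This overload makes common::Sigmoid callable from templated code with either scalar or vector arguments; overload resolution picks the forwarding version for avx::Float8. Illustration only:

    bst_float s = common::Sigmoid(0.5f);                  // existing scalar version
    avx::Float8 v = common::Sigmoid(avx::Float8(0.5f));   // forwards to avx::Sigmoid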
src/objective/regression_obj.cc
@@ -11,6 +11,7 @@
#include <algorithm>
#include <utility>
#include "../common/math.h"
+#include "../common/avx_helpers.h"

namespace xgboost {
namespace obj {
@@ -20,22 +21,28 @@ DMLC_REGISTRY_FILE_TAG(regression_obj);
// common regressions
// linear regression
struct LinearSquareLoss {
-  static bst_float PredTransform(bst_float x) { return x; }
+  template <typename T>
+  static T PredTransform(T x) { return x; }
  static bool CheckLabel(bst_float x) { return true; }
-  static bst_float FirstOrderGradient(bst_float predt, bst_float label) { return predt - label; }
-  static bst_float SecondOrderGradient(bst_float predt, bst_float label) { return 1.0f; }
+  template <typename T>
+  static T FirstOrderGradient(T predt, T label) { return predt - label; }
+  template <typename T>
+  static T SecondOrderGradient(T predt, T label) { return T(1.0f); }
  static bst_float ProbToMargin(bst_float base_score) { return base_score; }
  static const char* LabelErrorMsg() { return ""; }
  static const char* DefaultEvalMetric() { return "rmse"; }
};
// logistic loss for probability regression task
struct LogisticRegression {
-  static bst_float PredTransform(bst_float x) { return common::Sigmoid(x); }
+  template <typename T>
+  static T PredTransform(T x) { return common::Sigmoid(x); }
  static bool CheckLabel(bst_float x) { return x >= 0.0f && x <= 1.0f; }
-  static bst_float FirstOrderGradient(bst_float predt, bst_float label) { return predt - label; }
-  static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
-    const float eps = 1e-16f;
-    return std::max(predt * (1.0f - predt), eps);
+  template <typename T>
+  static T FirstOrderGradient(T predt, T label) { return predt - label; }
+  template <typename T>
+  static T SecondOrderGradient(T predt, T label) {
+    const T eps = T(1e-16f);
+    return std::max(predt * (T(1.0f) - predt), eps);
  }
  static bst_float ProbToMargin(bst_float base_score) {
    CHECK(base_score > 0.0f && base_score < 1.0f)
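Templating the loss statics is what lets the same struct drive both code paths: the compiler instantiates T = bst_float in the scalar remainder loop and T = avx::Float8 in the vectorised loop below. A sketch (illustration only):

    bst_float g1 = LogisticRegression::FirstOrderGradient(0.7f, 1.0f);  // scalar
    avx::Float8 g8 = LogisticRegression::FirstOrderGradient(
        avx::Float8(0.7f), avx::Float8(1.0f));                          // 8-wide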
@@ -53,28 +60,28 @@ struct LogisticClassification : public LogisticRegression {
};
// logistic loss, but predict un-transformed margin
struct LogisticRaw : public LogisticRegression {
-  static bst_float PredTransform(bst_float x) { return x; }
-  static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
+  template <typename T>
+  static T PredTransform(T x) { return x; }
+  template <typename T>
+  static T FirstOrderGradient(T predt, T label) {
    predt = common::Sigmoid(predt);
    return predt - label;
  }
-  static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
-    const float eps = 1e-16f;
+  template <typename T>
+  static T SecondOrderGradient(T predt, T label) {
+    const T eps = T(1e-16f);
    predt = common::Sigmoid(predt);
-    return std::max(predt * (1.0f - predt), eps);
+    return std::max(predt * (T(1.0f) - predt), eps);
  }
  static const char* DefaultEvalMetric() { return "auc"; }
};

struct RegLossParam : public dmlc::Parameter<RegLossParam> {
  float scale_pos_weight;
  int nthread;
  // declare parameters
  DMLC_DECLARE_PARAMETER(RegLossParam) {
    DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f)
      .describe("Scale the weight of positive examples by this factor");
    DMLC_DECLARE_FIELD(nthread).set_default(0).describe(
      "Number of threads to use.");
  }
};

@@ -98,23 +105,36 @@ class RegLossObj : public ObjFunction {

    this->LazyCheckLabels(info.labels);
    out_gpair->resize(preds.size());
+    const omp_ulong n = static_cast<omp_ulong>(preds.size());
+    auto gpair_ptr = out_gpair->data();
+    avx::Float8 scale(param_.scale_pos_weight);

-    // start calculating gradient
-    const omp_ulong ndata = static_cast<omp_ulong>(preds.size());
-    int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread;
-#pragma omp parallel for schedule(static) num_threads(nthread)
-    for (omp_ulong i = 0; i < ndata; ++i) {
+    const omp_ulong remainder = n % 8;
+    int nthread = omp_get_max_threads();
+    // Use a maximum of 8 threads
+#pragma omp parallel for schedule(static) num_threads(std::min(8, nthread))
+    for (int i = 0; i < n - remainder; i += 8) {
+      avx::Float8 y(&info.labels[i]);
+      avx::Float8 p = Loss::PredTransform(avx::Float8(&preds[i]));
+      avx::Float8 w = info.weights.empty() ? avx::Float8(1.0f)
+                                           : avx::Float8(&info.weights[i]);
+      // Adjust weight
+      w += y * (scale * w - w);
+      avx::Float8 grad = Loss::FirstOrderGradient(p, y);
+      avx::Float8 hess = Loss::SecondOrderGradient(p, y);
+      avx::StoreGpair(gpair_ptr + i, grad * w, hess * w);
+    }
+    for (omp_ulong i = n - remainder; i < n; ++i) {
      auto y = info.labels[i];
      bst_float p = Loss::PredTransform(preds[i]);
      bst_float w = info.GetWeight(i);
      // Branchless version of the below function
      // The branch is particularly slow as the cpu cannot predict the label
      // with any accuracy resulting in frequent pipeline stalls
      // if (info.labels[i] == 1.0f) w *= param_.scale_pos_weight;
      w += y * ((param_.scale_pos_weight * w) - w);
      (*out_gpair)[i] = bst_gpair(Loss::FirstOrderGradient(p, y) * w,
                                  Loss::SecondOrderGradient(p, y) * w);
    }

+    // Reset omp max threads
+    omp_set_num_threads(nthread);
  }
  const char *DefaultEvalMetric() const override {
    return Loss::DefaultEvalMetric();
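The restructured loop splits work at the vector width: the OpenMP loop walks the first n - remainder elements eight at a time through the Float8 path, and the scalar tail loop finishes the last n % 8 elements with the original code, so every element is covered regardless of n. A worked example (illustration only):

    // n = 21  ->  remainder = 21 % 8 = 5
    // vector loop:  i = 0, 8        (covers elements 0..15)
    // scalar tail:  i = 16, ..., 20 (covers elements 16..20)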
@@ -165,13 +185,10 @@ XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, "binary:logitraw")
// declare parameter
struct PoissonRegressionParam : public dmlc::Parameter<PoissonRegressionParam> {
  float max_delta_step;
-  int nthread;
  DMLC_DECLARE_PARAMETER(PoissonRegressionParam) {
    DMLC_DECLARE_FIELD(max_delta_step).set_lower_bound(0.0f).set_default(0.7f)
      .describe("Maximum delta step we allow each weight estimation to be." \
                " This parameter is required for possion regression.");
-    DMLC_DECLARE_FIELD(nthread).set_default(0).describe(
-      "Number of threads to use.");
  }
};

@@ -194,8 +211,7 @@ class PoissonRegression : public ObjFunction {
    bool label_correct = true;
    // start calculating gradient
    const omp_ulong ndata = static_cast<omp_ulong>(preds.size()); // NOLINT(*)
-    int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread;
-#pragma omp parallel for schedule(static) num_threads(nthread)
+#pragma omp parallel for schedule(static)
    for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
      bst_float p = preds[i];
      bst_float w = info.GetWeight(i);
@@ -212,8 +228,7 @@ class PoissonRegression : public ObjFunction {
  void PredTransform(std::vector<bst_float> *io_preds) override {
    std::vector<bst_float> &preds = *io_preds;
    const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
-    int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread;
-#pragma omp parallel for schedule(static) num_threads(nthread)
+#pragma omp parallel for schedule(static)
    for (long j = 0; j < ndata; ++j) { // NOLINT(*)
      preds[j] = std::exp(preds[j]);
    }
@@ -297,12 +312,9 @@ XGBOOST_REGISTER_OBJECTIVE(GammaRegression, "reg:gamma")
// declare parameter
struct TweedieRegressionParam : public dmlc::Parameter<TweedieRegressionParam> {
  float tweedie_variance_power;
-  int nthread;
  DMLC_DECLARE_PARAMETER(TweedieRegressionParam) {
    DMLC_DECLARE_FIELD(tweedie_variance_power).set_range(1.0f, 2.0f).set_default(1.5f)
      .describe("Tweedie variance power. Must be between in range [1, 2).");
-    DMLC_DECLARE_FIELD(nthread).set_default(0).describe(
-      "Number of threads to use.");
  }
};

@@ -325,8 +337,7 @@ class TweedieRegression : public ObjFunction {
    bool label_correct = true;
    // start calculating gradient
    const omp_ulong ndata = static_cast<omp_ulong>(preds.size()); // NOLINT(*)
-    int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread;
-#pragma omp parallel for schedule(static) num_threads(nthread)
+#pragma omp parallel for schedule(static)
    for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
      bst_float p = preds[i];
      bst_float w = info.GetWeight(i);
@@ -346,8 +357,7 @@ class TweedieRegression : public ObjFunction {
  void PredTransform(std::vector<bst_float> *io_preds) override {
    std::vector<bst_float> &preds = *io_preds;
    const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
-    int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread;
-#pragma omp parallel for schedule(static) num_threads(nthread)
+#pragma omp parallel for schedule(static)
    for (long j = 0; j < ndata; ++j) { // NOLINT(*)
      preds[j] = std::exp(preds[j]);
    }

@@ -1,8 +1,8 @@
/*!
 * Copyright 2017 XGBoost contributors
 */
-#include <thrust/reduce.h>
#include <thrust/execution_policy.h>
+#include <thrust/reduce.h>
#include <thrust/sequence.h>
#include <xgboost/tree_updater.h>
#include <algorithm>
@@ -564,7 +564,6 @@ class GPUHistMakerExperimental : public TreeUpdater {
                  const RegTree& tree) {
    monitor.Start("InitDataOnce");
    if (!initialised) {
-      CheckGradientMax(gpair);
      this->InitDataOnce(dmat);
    }
    monitor.Stop("InitDataOnce");
@@ -774,6 +773,9 @@ class GPUHistMakerExperimental : public TreeUpdater {

  void UpdateTree(const std::vector<bst_gpair>& gpair, DMatrix* p_fmat,
                  RegTree* p_tree) {
+    // Temporarily store number of threads so we can change it back later
+    int nthread = omp_get_max_threads();
+
    auto& tree = *p_tree;

    monitor.Start("InitData");
@@ -819,6 +821,9 @@ class GPUHistMakerExperimental : public TreeUpdater {
      monitor.Stop("EvaluateSplits");
    }
  }
+
+    // Reset omp num threads
+    omp_set_num_threads(nthread);
  }

  struct ExpandEntry {
@@ -91,10 +91,30 @@ if [ ${TASK} == "java_test" ]; then
fi

if [ ${TASK} == "cmake_test" ]; then
-  mkdir build
-  cd build
-  cmake ..
+  set -e
+  # Build gtest via cmake
+  wget https://github.com/google/googletest/archive/release-1.7.0.zip
+  unzip release-1.7.0.zip
+  mv googletest-release-1.7.0 gtest && cd gtest
+  cmake . && make
+  mkdir lib && mv libgtest.a lib
+  cd ..
+  rm -rf release-1.7.0.zip
+
+  # Build/test without AVX
+  mkdir build && cd build
+  cmake .. -DGOOGLE_TEST=ON
+  make
+  cd ..
+  ./testxgboost
+  rm -rf build
+
+  # Build/test with AVX
+  mkdir build && cd build
+  cmake .. -DGOOGLE_TEST=ON -DUSE_AVX=ON
+  make
+  cd ..
+  ./testxgboost
fi

if [ ${TASK} == "cpp_test" ]; then
