AVX gradients (#2878)

* AVX gradients

* Add google test for AVX

* Create fallback implementation, remove fma instruction

* Improved accuracy of AVX exp function
This commit is contained in:
Rory Mitchell 2017-11-27 08:56:01 +13:00 committed by GitHub
parent 3dcf966bc3
commit 24f527a1c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 465 additions and 47 deletions

View File

@ -9,6 +9,7 @@ msvc_use_static_runtime()
# Options # Options
option(USE_CUDA "Build with GPU acceleration") option(USE_CUDA "Build with GPU acceleration")
option(USE_AVX "Build with AVX instructions. May not produce identical results due to approximate math." OFF)
option(USE_NCCL "Build using NCCL for multi-GPU. Also requires USE_CUDA") option(USE_NCCL "Build using NCCL for multi-GPU. Also requires USE_CUDA")
option(JVM_BINDINGS "Build JVM bindings" OFF) option(JVM_BINDINGS "Build JVM bindings" OFF)
option(GOOGLE_TEST "Build google tests" OFF) option(GOOGLE_TEST "Build google tests" OFF)
@ -39,6 +40,17 @@ else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -funroll-loops") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -funroll-loops")
endif() endif()
# AVX
if(USE_AVX)
if(MSVC)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx")
endif()
add_definitions(-DXGBOOST_USE_AVX)
endif()
# compiled code customizations for R package # compiled code customizations for R package
if(R_LIB) if(R_LIB)
add_definitions( add_definitions(
@ -212,12 +224,12 @@ endif()
# Test # Test
if(GOOGLE_TEST) if(GOOGLE_TEST)
enable_testing()
find_package(GTest REQUIRED) find_package(GTest REQUIRED)
enable_testing()
file(GLOB_RECURSE TEST_SOURCES "tests/cpp/*.cc") file(GLOB_RECURSE TEST_SOURCES "tests/cpp/*.cc")
auto_source_group("${TEST_SOURCES}") auto_source_group("${TEST_SOURCES}")
include_directories(${GTEST_INCLUDE_DIRS}) include_directories(${GTEST_INCLUDE_DIR})
if(USE_CUDA) if(USE_CUDA)
file(GLOB_RECURSE CUDA_TEST_SOURCES "tests/cpp/*.cu") file(GLOB_RECURSE CUDA_TEST_SOURCES "tests/cpp/*.cu")
@ -228,7 +240,7 @@ if(GOOGLE_TEST)
add_executable(testxgboost ${TEST_SOURCES} ${CUDA_TEST_OBJS} $<TARGET_OBJECTS:objxgboost>) add_executable(testxgboost ${TEST_SOURCES} ${CUDA_TEST_OBJS} $<TARGET_OBJECTS:objxgboost>)
set_output_directory(testxgboost ${PROJECT_SOURCE_DIR}) set_output_directory(testxgboost ${PROJECT_SOURCE_DIR})
target_link_libraries(testxgboost ${GTEST_BOTH_LIBRARIES} ${LINK_LIBRARIES}) target_link_libraries(testxgboost ${GTEST_LIBRARIES} ${LINK_LIBRARIES})
add_test(TestXGBoost testxgboost) add_test(TestXGBoost testxgboost)
endif() endif()

View File

@ -0,0 +1,79 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Tries to find GTest headers and libraries.
#
# Usage of this module as follows:
#
# find_package(GTest)
#
# Variables used by this module, they can change the default behaviour and need
# to be set before calling find_package:
#
# GTest_HOME - When set, this path is inspected instead of standard library
# locations as the root of the GTest installation.
# The environment variable GTEST_HOME overrides this veriable.
#
# This module defines
# GTEST_INCLUDE_DIR, directory containing headers
# GTEST_LIBS, directory containing gtest libraries
# GTEST_STATIC_LIB, path to libgtest.a
# GTEST_SHARED_LIB, path to libgtest's shared library
# GTEST_FOUND, whether gtest has been found
find_path(GTEST_INCLUDE_DIR NAMES gtest/gtest.h gtest.h PATHS ${CMAKE_SOURCE_DIR}/gtest/include NO_DEFAULT_PATH)
find_library(GTEST_LIBRARIES NAMES gtest PATHS ${CMAKE_SOURCE_DIR}/gtest/lib NO_DEFAULT_PATH)
if (GTEST_INCLUDE_DIR )
message(STATUS "Found the GTest includes: ${GTEST_INCLUDE_DIR}")
endif ()
if (GTEST_INCLUDE_DIR AND GTEST_LIBRARIES)
set(GTEST_FOUND TRUE)
get_filename_component( GTEST_LIBS ${GTEST_LIBRARIES} PATH )
set(GTEST_LIB_NAME gtest)
set(GTEST_STATIC_LIB ${GTEST_LIBS}/${CMAKE_STATIC_LIBRARY_PREFIX}${GTEST_LIB_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX})
set(GTEST_MAIN_STATIC_LIB ${GTEST_LIBS}/${CMAKE_STATIC_LIBRARY_PREFIX}${GTEST_LIB_NAME}_main${CMAKE_STATIC_LIBRARY_SUFFIX})
set(GTEST_SHARED_LIB ${GTEST_LIBS}/${CMAKE_SHARED_LIBRARY_PREFIX}${GTEST_LIB_NAME}${CMAKE_SHARED_LIBRARY_SUFFIX})
else ()
set(GTEST_FOUND FALSE)
endif ()
if (GTEST_FOUND)
if (NOT GTest_FIND_QUIETLY)
message(STATUS "Found the GTest library: ${GTEST_LIBRARIES}")
endif ()
else ()
if (NOT GTest_FIND_QUIETLY)
set(GTEST_ERR_MSG "Could not find the GTest library. Looked in ")
if ( _gtest_roots )
set(GTEST_ERR_MSG "${GTEST_ERR_MSG} in ${_gtest_roots}.")
else ()
set(GTEST_ERR_MSG "${GTEST_ERR_MSG} system search paths.")
endif ()
if (GTest_FIND_REQUIRED)
message(FATAL_ERROR "${GTEST_ERR_MSG}")
else (GTest_FIND_REQUIRED)
message(STATUS "${GTEST_ERR_MSG}")
endif (GTest_FIND_REQUIRED)
endif ()
endif ()
mark_as_advanced(
GTEST_INCLUDE_DIR
GTEST_LIBS
GTEST_LIBRARIES
GTEST_STATIC_LIB
GTEST_SHARED_LIB
)

287
src/common/avx_helpers.h Normal file
View File

@ -0,0 +1,287 @@
/*!
* Copyright 2017 by Contributors
* \author Rory Mitchell
*/
#pragma once
#include <algorithm>
#include "xgboost/base.h"
#ifdef XGBOOST_USE_AVX
namespace avx {
/**
* \struct Float8
*
* \brief Helper class for processing a vector of eight floats using AVX
* instructions. Implements basic math operators.
*/
struct Float8 {
__m256 x;
explicit Float8(const __m256& x) : x(x) {}
explicit Float8(const float& val) : x(_mm256_broadcast_ss(&val)) {}
explicit Float8(const float* vec) : x(_mm256_loadu_ps(vec)) {}
Float8() : x() {}
Float8& operator+=(const Float8& rhs) {
x = _mm256_add_ps(x, rhs.x);
return *this;
}
Float8& operator-=(const Float8& rhs) {
x = _mm256_sub_ps(x, rhs.x);
return *this;
}
Float8& operator*=(const Float8& rhs) {
x = _mm256_mul_ps(x, rhs.x);
return *this;
}
Float8& operator/=(const Float8& rhs) {
x = _mm256_div_ps(x, rhs.x);
return *this;
}
void Print() {
float* f = reinterpret_cast<float*>(&x);
printf("%f %f %f %f %f %f %f %f\n", f[0], f[1], f[2], f[3], f[4], f[5],
f[6], f[7]);
}
};
inline Float8 operator+(Float8 lhs, const Float8& rhs) {
lhs += rhs;
return lhs;
}
inline Float8 operator-(Float8 lhs, const Float8& rhs) {
lhs -= rhs;
return lhs;
}
inline Float8 operator*(Float8 lhs, const Float8& rhs) {
lhs *= rhs;
return lhs;
}
inline Float8 operator/(Float8 lhs, const Float8& rhs) {
lhs /= rhs;
return lhs;
}
inline Float8 round(const Float8& x) {
return Float8(_mm256_round_ps(x.x, _MM_FROUND_TO_NEAREST_INT));
}
} // namespace avx
// Overload std::max/min
namespace std {
inline avx::Float8 max(const avx::Float8& a, const avx::Float8& b) {
return avx::Float8(_mm256_max_ps(a.x, b.x));
}
inline avx::Float8 min(const avx::Float8& a, const avx::Float8& b) {
return avx::Float8(_mm256_min_ps(a.x, b.x));
}
} // namespace std
namespace avx {
// https://codingforspeed.com/using-faster-exponential-approximation/
inline Float8 Exp4096(Float8 x) {
x *= Float8(1.0f / 4096.0f);
x += Float8(1.0f);
x *= x;
x *= x;
x *= x;
x *= x;
x *= x;
x *= x;
x *= x;
x *= x;
x *= x;
x *= x;
x *= x;
x *= x;
return x;
}
inline Float8 pow2n(Float8 const& n) {
const float pow2_23 = 8388608.0; // 2^23
const float bias = 127.0; // bias in exponent
Float8 a =
n + Float8(bias + pow2_23); // put n + bias in least significant bits
__m256i b = _mm256_castps_si256(a.x);
// Do bit shift in SSE so we don't have to use AVX2 instructions
__m128i c1 = _mm256_castsi256_si128(b);
b = _mm256_permute2f128_si256(b, b, 1);
__m128i c2 = _mm256_castsi256_si128(b);
c1 = _mm_slli_epi32(c1, 23);
c2 = _mm_slli_epi32(c2, 23);
__m256i c = _mm256_insertf128_si256(_mm256_castsi128_si256(c1), (c2), 0x1);
return Float8(_mm256_castsi256_ps(c));
}
inline Float8 polynomial_5(Float8 const& x, const float c0, const float c1,
const float c2, const float c3, const float c4,
const float c5) {
// calculates polynomial c5*x^5 + c4*x^4 + c3*x^3 + c2*x^2 + c1*x + c0
Float8 x2 = x * x;
Float8 x4 = x2 * x2;
return (Float8(c2) + Float8(c3) * x) * x2 +
((Float8(c4) + Float8(c5) * x) * x4 + (Float8(c0) + Float8(c1) * x));
}
// AVX exp Function based off Agner Fog's vector library
// https://github.com/darealshinji/vectorclass/blob/master/vectormath_exp.h
// Modified so it doesn't require AVX2 instructions
// Clamps input values to the range -87.3f, +87.3f
inline Float8 ExpAgner(Float8 x) {
// Clamp input values
float max_x = 87.3f;
x = std::min(x, Float8(max_x));
x = std::max(x, Float8(-max_x));
// 1/log(2)
const float log2e = 1.44269504088896340736f;
// Taylor coefficients
const float P0expf = 1.f / 2.f;
const float P1expf = 1.f / 6.f;
const float P2expf = 1.f / 24.f;
const float P3expf = 1.f / 120.f;
const float P4expf = 1.f / 720.f;
const float P5expf = 1.f / 5040.f;
const float ln2f_hi = 0.693359375f;
const float ln2f_lo = -2.12194440e-4f;
Float8 r = round(x * Float8(log2e));
x -= r * Float8(ln2f_hi);
x -= r * Float8(ln2f_lo);
Float8 x2 = x * x;
Float8 z = polynomial_5(x, P0expf, P1expf, P2expf, P3expf, P4expf, P5expf);
z *= x2;
z += x;
// multiply by power of 2
Float8 n2 = pow2n(r);
z = (z + Float8(1.0f)) * n2;
return z;
}
inline Float8 Sigmoid(Float8 x) {
Float8 exp = ExpAgner(x * Float8(-1.0f));
x = Float8(1.0f) + exp;
return Float8(_mm256_rcp_ps(x.x));
}
// Store 8 gradient pairs given vectors containing gradient and Hessian
inline void StoreGpair(xgboost::bst_gpair* dst, const Float8& grad,
const Float8& hess) {
float* ptr = reinterpret_cast<float*>(dst);
__m256 gpair_low = _mm256_unpacklo_ps(grad.x, hess.x);
__m256 gpair_high = _mm256_unpackhi_ps(grad.x, hess.x);
_mm256_storeu_ps(ptr, _mm256_permute2f128_ps(gpair_low, gpair_high, 0x20));
_mm256_storeu_ps(ptr + 8,
_mm256_permute2f128_ps(gpair_low, gpair_high, 0x31));
}
} // namespace avx
#else
namespace avx {
/**
* \struct Float8
*
* \brief Fallback implementation not using AVX.
*/
struct Float8 {
float x[8];
explicit Float8(const float& val) {
for (int i = 0; i < 8; i++) {
x[i] = val;
}
}
explicit Float8(const float* vec) {
for (int i = 0; i < 8; i++) {
x[i] = vec[i];
}
}
Float8() {}
Float8& operator+=(const Float8& rhs) {
for (int i = 0; i < 8; i++) {
x[i] += rhs.x[i];
}
return *this;
}
Float8& operator-=(const Float8& rhs) {
for (int i = 0; i < 8; i++) {
x[i] -= rhs.x[i];
}
return *this;
}
Float8& operator*=(const Float8& rhs) {
for (int i = 0; i < 8; i++) {
x[i] *= rhs.x[i];
}
return *this;
}
Float8& operator/=(const Float8& rhs) {
for (int i = 0; i < 8; i++) {
x[i] /= rhs.x[i];
}
return *this;
}
void Print() {
float* f = reinterpret_cast<float*>(&x);
printf("%f %f %f %f %f %f %f %f\n", f[0], f[1], f[2], f[3], f[4], f[5],
f[6], f[7]);
}
};
inline Float8 operator+(Float8 lhs, const Float8& rhs) {
lhs += rhs;
return lhs;
}
inline Float8 operator-(Float8 lhs, const Float8& rhs) {
lhs -= rhs;
return lhs;
}
inline Float8 operator*(Float8 lhs, const Float8& rhs) {
lhs *= rhs;
return lhs;
}
inline Float8 operator/(Float8 lhs, const Float8& rhs) {
lhs /= rhs;
return lhs;
}
// Store 8 gradient pairs given vectors containing gradient and Hessian
inline void StoreGpair(xgboost::bst_gpair* dst, const Float8& grad,
const Float8& hess) {
for (int i = 0; i < 8; i++) {
dst[i] = xgboost::bst_gpair(grad.x[i], hess.x[i]);
}
}
inline Float8 Sigmoid(Float8 x) {
Float8 sig;
for (int i = 0; i < 8; i++) {
sig.x[i] = 1.0f / (1.0f + std::exp(-x.x[i]));
}
return sig;
}
} // namespace avx
namespace std {
inline avx::Float8 max(const avx::Float8& a, const avx::Float8& b) {
avx::Float8 max;
for (int i = 0; i < 8; i++) {
max.x[i] = std::max(a.x[i], b.x[i]);
}
return max;
}
inline avx::Float8 min(const avx::Float8& a, const avx::Float8& b) {
avx::Float8 min;
for (int i = 0; i < 8; i++) {
min.x[i] = std::min(a.x[i], b.x[i]);
}
return min;
}
} // namespace std
#endif

View File

@ -11,6 +11,7 @@
#include <vector> #include <vector>
#include <cmath> #include <cmath>
#include <algorithm> #include <algorithm>
#include "avx_helpers.h"
namespace xgboost { namespace xgboost {
namespace common { namespace common {
@ -23,6 +24,10 @@ inline float Sigmoid(float x) {
return 1.0f / (1.0f + std::exp(-x)); return 1.0f / (1.0f + std::exp(-x));
} }
inline avx::Float8 Sigmoid(avx::Float8 x) {
return avx::Sigmoid(x);
}
/*! /*!
* \brief do inplace softmax transformaton on p_rec * \brief do inplace softmax transformaton on p_rec
* \param p_rec the input/output vector of the values. * \param p_rec the input/output vector of the values.

View File

@ -11,6 +11,7 @@
#include <algorithm> #include <algorithm>
#include <utility> #include <utility>
#include "../common/math.h" #include "../common/math.h"
#include "../common/avx_helpers.h"
namespace xgboost { namespace xgboost {
namespace obj { namespace obj {
@ -20,22 +21,28 @@ DMLC_REGISTRY_FILE_TAG(regression_obj);
// common regressions // common regressions
// linear regression // linear regression
struct LinearSquareLoss { struct LinearSquareLoss {
static bst_float PredTransform(bst_float x) { return x; } template <typename T>
static T PredTransform(T x) { return x; }
static bool CheckLabel(bst_float x) { return true; } static bool CheckLabel(bst_float x) { return true; }
static bst_float FirstOrderGradient(bst_float predt, bst_float label) { return predt - label; } template <typename T>
static bst_float SecondOrderGradient(bst_float predt, bst_float label) { return 1.0f; } static T FirstOrderGradient(T predt, T label) { return predt - label; }
template <typename T>
static T SecondOrderGradient(T predt, T label) { return T(1.0f); }
static bst_float ProbToMargin(bst_float base_score) { return base_score; } static bst_float ProbToMargin(bst_float base_score) { return base_score; }
static const char* LabelErrorMsg() { return ""; } static const char* LabelErrorMsg() { return ""; }
static const char* DefaultEvalMetric() { return "rmse"; } static const char* DefaultEvalMetric() { return "rmse"; }
}; };
// logistic loss for probability regression task // logistic loss for probability regression task
struct LogisticRegression { struct LogisticRegression {
static bst_float PredTransform(bst_float x) { return common::Sigmoid(x); } template <typename T>
static T PredTransform(T x) { return common::Sigmoid(x); }
static bool CheckLabel(bst_float x) { return x >= 0.0f && x <= 1.0f; } static bool CheckLabel(bst_float x) { return x >= 0.0f && x <= 1.0f; }
static bst_float FirstOrderGradient(bst_float predt, bst_float label) { return predt - label; } template <typename T>
static bst_float SecondOrderGradient(bst_float predt, bst_float label) { static T FirstOrderGradient(T predt, T label) { return predt - label; }
const float eps = 1e-16f; template <typename T>
return std::max(predt * (1.0f - predt), eps); static T SecondOrderGradient(T predt, T label) {
const T eps = T(1e-16f);
return std::max(predt * (T(1.0f) - predt), eps);
} }
static bst_float ProbToMargin(bst_float base_score) { static bst_float ProbToMargin(bst_float base_score) {
CHECK(base_score > 0.0f && base_score < 1.0f) CHECK(base_score > 0.0f && base_score < 1.0f)
@ -53,28 +60,28 @@ struct LogisticClassification : public LogisticRegression {
}; };
// logistic loss, but predict un-transformed margin // logistic loss, but predict un-transformed margin
struct LogisticRaw : public LogisticRegression { struct LogisticRaw : public LogisticRegression {
static bst_float PredTransform(bst_float x) { return x; } template <typename T>
static bst_float FirstOrderGradient(bst_float predt, bst_float label) { static T PredTransform(T x) { return x; }
template <typename T>
static T FirstOrderGradient(T predt, T label) {
predt = common::Sigmoid(predt); predt = common::Sigmoid(predt);
return predt - label; return predt - label;
} }
static bst_float SecondOrderGradient(bst_float predt, bst_float label) { template <typename T>
const float eps = 1e-16f; static T SecondOrderGradient(T predt, T label) {
const T eps = T(1e-16f);
predt = common::Sigmoid(predt); predt = common::Sigmoid(predt);
return std::max(predt * (1.0f - predt), eps); return std::max(predt * (T(1.0f) - predt), eps);
} }
static const char* DefaultEvalMetric() { return "auc"; } static const char* DefaultEvalMetric() { return "auc"; }
}; };
struct RegLossParam : public dmlc::Parameter<RegLossParam> { struct RegLossParam : public dmlc::Parameter<RegLossParam> {
float scale_pos_weight; float scale_pos_weight;
int nthread;
// declare parameters // declare parameters
DMLC_DECLARE_PARAMETER(RegLossParam) { DMLC_DECLARE_PARAMETER(RegLossParam) {
DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f) DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f)
.describe("Scale the weight of positive examples by this factor"); .describe("Scale the weight of positive examples by this factor");
DMLC_DECLARE_FIELD(nthread).set_default(0).describe(
"Number of threads to use.");
} }
}; };
@ -98,23 +105,36 @@ class RegLossObj : public ObjFunction {
this->LazyCheckLabels(info.labels); this->LazyCheckLabels(info.labels);
out_gpair->resize(preds.size()); out_gpair->resize(preds.size());
const omp_ulong n = static_cast<omp_ulong>(preds.size());
auto gpair_ptr = out_gpair->data();
avx::Float8 scale(param_.scale_pos_weight);
// start calculating gradient const omp_ulong remainder = n % 8;
const omp_ulong ndata = static_cast<omp_ulong>(preds.size()); int nthread = omp_get_max_threads();
int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread; // Use a maximum of 8 threads
#pragma omp parallel for schedule(static) num_threads(nthread) #pragma omp parallel for schedule(static) num_threads(std::min(8, nthread))
for (omp_ulong i = 0; i < ndata; ++i) { for (int i = 0; i < n - remainder; i += 8) {
avx::Float8 y(&info.labels[i]);
avx::Float8 p = Loss::PredTransform(avx::Float8(&preds[i]));
avx::Float8 w = info.weights.empty() ? avx::Float8(1.0f)
: avx::Float8(&info.weights[i]);
// Adjust weight
w += y * (scale * w - w);
avx::Float8 grad = Loss::FirstOrderGradient(p, y);
avx::Float8 hess = Loss::SecondOrderGradient(p, y);
avx::StoreGpair(gpair_ptr + i, grad * w, hess * w);
}
for (omp_ulong i = n - remainder; i < n; ++i) {
auto y = info.labels[i]; auto y = info.labels[i];
bst_float p = Loss::PredTransform(preds[i]); bst_float p = Loss::PredTransform(preds[i]);
bst_float w = info.GetWeight(i); bst_float w = info.GetWeight(i);
// Branchless version of the below function
// The branch is particularly slow as the cpu cannot predict the label
// with any accuracy resulting in frequent pipeline stalls
// if (info.labels[i] == 1.0f) w *= param_.scale_pos_weight;
w += y * ((param_.scale_pos_weight * w) - w); w += y * ((param_.scale_pos_weight * w) - w);
(*out_gpair)[i] = bst_gpair(Loss::FirstOrderGradient(p, y) * w, (*out_gpair)[i] = bst_gpair(Loss::FirstOrderGradient(p, y) * w,
Loss::SecondOrderGradient(p, y) * w); Loss::SecondOrderGradient(p, y) * w);
} }
// Reset omp max threads
omp_set_num_threads(nthread);
} }
const char *DefaultEvalMetric() const override { const char *DefaultEvalMetric() const override {
return Loss::DefaultEvalMetric(); return Loss::DefaultEvalMetric();
@ -165,13 +185,10 @@ XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, "binary:logitraw")
// declare parameter // declare parameter
struct PoissonRegressionParam : public dmlc::Parameter<PoissonRegressionParam> { struct PoissonRegressionParam : public dmlc::Parameter<PoissonRegressionParam> {
float max_delta_step; float max_delta_step;
int nthread;
DMLC_DECLARE_PARAMETER(PoissonRegressionParam) { DMLC_DECLARE_PARAMETER(PoissonRegressionParam) {
DMLC_DECLARE_FIELD(max_delta_step).set_lower_bound(0.0f).set_default(0.7f) DMLC_DECLARE_FIELD(max_delta_step).set_lower_bound(0.0f).set_default(0.7f)
.describe("Maximum delta step we allow each weight estimation to be." \ .describe("Maximum delta step we allow each weight estimation to be." \
" This parameter is required for possion regression."); " This parameter is required for possion regression.");
DMLC_DECLARE_FIELD(nthread).set_default(0).describe(
"Number of threads to use.");
} }
}; };
@ -194,8 +211,7 @@ class PoissonRegression : public ObjFunction {
bool label_correct = true; bool label_correct = true;
// start calculating gradient // start calculating gradient
const omp_ulong ndata = static_cast<omp_ulong>(preds.size()); // NOLINT(*) const omp_ulong ndata = static_cast<omp_ulong>(preds.size()); // NOLINT(*)
int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread; #pragma omp parallel for schedule(static)
#pragma omp parallel for schedule(static) num_threads(nthread)
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*) for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
bst_float p = preds[i]; bst_float p = preds[i];
bst_float w = info.GetWeight(i); bst_float w = info.GetWeight(i);
@ -212,8 +228,7 @@ class PoissonRegression : public ObjFunction {
void PredTransform(std::vector<bst_float> *io_preds) override { void PredTransform(std::vector<bst_float> *io_preds) override {
std::vector<bst_float> &preds = *io_preds; std::vector<bst_float> &preds = *io_preds;
const long ndata = static_cast<long>(preds.size()); // NOLINT(*) const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread; #pragma omp parallel for schedule(static)
#pragma omp parallel for schedule(static) num_threads(nthread)
for (long j = 0; j < ndata; ++j) { // NOLINT(*) for (long j = 0; j < ndata; ++j) { // NOLINT(*)
preds[j] = std::exp(preds[j]); preds[j] = std::exp(preds[j]);
} }
@ -297,12 +312,9 @@ XGBOOST_REGISTER_OBJECTIVE(GammaRegression, "reg:gamma")
// declare parameter // declare parameter
struct TweedieRegressionParam : public dmlc::Parameter<TweedieRegressionParam> { struct TweedieRegressionParam : public dmlc::Parameter<TweedieRegressionParam> {
float tweedie_variance_power; float tweedie_variance_power;
int nthread;
DMLC_DECLARE_PARAMETER(TweedieRegressionParam) { DMLC_DECLARE_PARAMETER(TweedieRegressionParam) {
DMLC_DECLARE_FIELD(tweedie_variance_power).set_range(1.0f, 2.0f).set_default(1.5f) DMLC_DECLARE_FIELD(tweedie_variance_power).set_range(1.0f, 2.0f).set_default(1.5f)
.describe("Tweedie variance power. Must be between in range [1, 2)."); .describe("Tweedie variance power. Must be between in range [1, 2).");
DMLC_DECLARE_FIELD(nthread).set_default(0).describe(
"Number of threads to use.");
} }
}; };
@ -325,8 +337,7 @@ class TweedieRegression : public ObjFunction {
bool label_correct = true; bool label_correct = true;
// start calculating gradient // start calculating gradient
const omp_ulong ndata = static_cast<omp_ulong>(preds.size()); // NOLINT(*) const omp_ulong ndata = static_cast<omp_ulong>(preds.size()); // NOLINT(*)
int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread; #pragma omp parallel for schedule(static)
#pragma omp parallel for schedule(static) num_threads(nthread)
for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*) for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
bst_float p = preds[i]; bst_float p = preds[i];
bst_float w = info.GetWeight(i); bst_float w = info.GetWeight(i);
@ -346,8 +357,7 @@ class TweedieRegression : public ObjFunction {
void PredTransform(std::vector<bst_float> *io_preds) override { void PredTransform(std::vector<bst_float> *io_preds) override {
std::vector<bst_float> &preds = *io_preds; std::vector<bst_float> &preds = *io_preds;
const long ndata = static_cast<long>(preds.size()); // NOLINT(*) const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread; #pragma omp parallel for schedule(static)
#pragma omp parallel for schedule(static) num_threads(nthread)
for (long j = 0; j < ndata; ++j) { // NOLINT(*) for (long j = 0; j < ndata; ++j) { // NOLINT(*)
preds[j] = std::exp(preds[j]); preds[j] = std::exp(preds[j]);
} }

View File

@ -1,8 +1,8 @@
/*! /*!
* Copyright 2017 XGBoost contributors * Copyright 2017 XGBoost contributors
*/ */
#include <thrust/reduce.h>
#include <thrust/execution_policy.h> #include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h> #include <thrust/sequence.h>
#include <xgboost/tree_updater.h> #include <xgboost/tree_updater.h>
#include <algorithm> #include <algorithm>
@ -564,7 +564,6 @@ class GPUHistMakerExperimental : public TreeUpdater {
const RegTree& tree) { const RegTree& tree) {
monitor.Start("InitDataOnce"); monitor.Start("InitDataOnce");
if (!initialised) { if (!initialised) {
CheckGradientMax(gpair);
this->InitDataOnce(dmat); this->InitDataOnce(dmat);
} }
monitor.Stop("InitDataOnce"); monitor.Stop("InitDataOnce");
@ -774,6 +773,9 @@ class GPUHistMakerExperimental : public TreeUpdater {
void UpdateTree(const std::vector<bst_gpair>& gpair, DMatrix* p_fmat, void UpdateTree(const std::vector<bst_gpair>& gpair, DMatrix* p_fmat,
RegTree* p_tree) { RegTree* p_tree) {
// Temporarily store number of threads so we can change it back later
int nthread = omp_get_max_threads();
auto& tree = *p_tree; auto& tree = *p_tree;
monitor.Start("InitData"); monitor.Start("InitData");
@ -819,6 +821,9 @@ class GPUHistMakerExperimental : public TreeUpdater {
monitor.Stop("EvaluateSplits"); monitor.Stop("EvaluateSplits");
} }
} }
// Reset omp num threads
omp_set_num_threads(nthread);
} }
struct ExpandEntry { struct ExpandEntry {

View File

@ -91,10 +91,30 @@ if [ ${TASK} == "java_test" ]; then
fi fi
if [ ${TASK} == "cmake_test" ]; then if [ ${TASK} == "cmake_test" ]; then
mkdir build set -e
cd build # Build gtest via cmake
cmake .. wget https://github.com/google/googletest/archive/release-1.7.0.zip
unzip release-1.7.0.zip
mv googletest-release-1.7.0 gtest && cd gtest
cmake . && make
mkdir lib && mv libgtest.a lib
cd ..
rm -rf release-1.7.0.zip
# Build/test without AVX
mkdir build && cd build
cmake .. -DGOOGLE_TEST=ON
make make
cd ..
./testxgboost
rm -rf build
# Build/test with AVX
mkdir build && cd build
cmake .. -DGOOGLE_TEST=ON -DUSE_AVX=ON
make
cd ..
./testxgboost
fi fi
if [ ${TASK} == "cpp_test" ]; then if [ ${TASK} == "cpp_test" ]; then