diff --git a/CMakeLists.txt b/CMakeLists.txt
index 18a2ec04f..22e8c4760 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -9,6 +9,7 @@ msvc_use_static_runtime()
 
 # Options
 option(USE_CUDA "Build with GPU acceleration")
+option(USE_AVX "Build with AVX instructions. May not produce identical results due to approximate math." OFF)
 option(USE_NCCL "Build using NCCL for multi-GPU. Also requires USE_CUDA")
 option(JVM_BINDINGS "Build JVM bindings" OFF)
 option(GOOGLE_TEST "Build google tests" OFF)
@@ -39,6 +40,17 @@
 else()
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -funroll-loops")
 endif()
+
+# AVX
+if(USE_AVX)
+  if(MSVC)
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX")
+  else()
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx")
+  endif()
+  add_definitions(-DXGBOOST_USE_AVX)
+endif()
 
 # compiled code customizations for R package
 if(R_LIB)
   add_definitions(
@@ -212,12 +224,12 @@ endif()
 
 # Test
 if(GOOGLE_TEST)
-  enable_testing()
   find_package(GTest REQUIRED)
+  enable_testing()
   file(GLOB_RECURSE TEST_SOURCES "tests/cpp/*.cc")
   auto_source_group("${TEST_SOURCES}")
-  include_directories(${GTEST_INCLUDE_DIRS})
+  include_directories(${GTEST_INCLUDE_DIR})
 
   if(USE_CUDA)
     file(GLOB_RECURSE CUDA_TEST_SOURCES "tests/cpp/*.cu")
@@ -228,7 +240,7 @@ if(GOOGLE_TEST)
   add_executable(testxgboost ${TEST_SOURCES} ${CUDA_TEST_OBJS} $<TARGET_OBJECTS:objxgboost>)
   set_output_directory(testxgboost ${PROJECT_SOURCE_DIR})
-  target_link_libraries(testxgboost ${GTEST_BOTH_LIBRARIES} ${LINK_LIBRARIES})
+  target_link_libraries(testxgboost ${GTEST_LIBRARIES} ${LINK_LIBRARIES})
 
   add_test(TestXGBoost testxgboost)
 endif()
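// A minimal sketch (not part of the patch) of how the -DXGBOOST_USE_AVX
// definition added above is consumed: src/common/avx_helpers.h (below)
// compiles an intrinsics-based Float8 when the macro is set and a plain
// scalar fallback otherwise, so the rest of the code base is unchanged
// either way.
#include <cstdio>

#ifdef XGBOOST_USE_AVX
const char* kFloat8Backend = "AVX intrinsics";
#else
const char* kFloat8Backend = "scalar fallback";
#endif

int main() { printf("Float8 backend: %s\n", kFloat8Backend); }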
diff --git a/cmake/modules/FindGTest.cmake b/cmake/modules/FindGTest.cmake
new file mode 100644
index 000000000..52fec5991
--- /dev/null
+++ b/cmake/modules/FindGTest.cmake
@@ -0,0 +1,79 @@
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Tries to find GTest headers and libraries.
+#
+# Usage of this module as follows:
+#
+#   find_package(GTest)
+#
+# Variables used by this module, they can change the default behaviour and need
+# to be set before calling find_package:
+#
+#   GTest_HOME - When set, this path is inspected instead of standard library
+#                locations as the root of the GTest installation.
+#                The environment variable GTEST_HOME overrides this variable.
+#
+# This module defines
+#   GTEST_INCLUDE_DIR, directory containing headers
+#   GTEST_LIBS, directory containing gtest libraries
+#   GTEST_STATIC_LIB, path to libgtest.a
+#   GTEST_SHARED_LIB, path to libgtest's shared library
+#   GTEST_FOUND, whether gtest has been found
+
+find_path(GTEST_INCLUDE_DIR NAMES gtest/gtest.h gtest.h PATHS ${CMAKE_SOURCE_DIR}/gtest/include NO_DEFAULT_PATH)
+find_library(GTEST_LIBRARIES NAMES gtest PATHS ${CMAKE_SOURCE_DIR}/gtest/lib NO_DEFAULT_PATH)
+
+if (GTEST_INCLUDE_DIR)
+  message(STATUS "Found the GTest includes: ${GTEST_INCLUDE_DIR}")
+endif ()
+
+if (GTEST_INCLUDE_DIR AND GTEST_LIBRARIES)
+  set(GTEST_FOUND TRUE)
+  get_filename_component(GTEST_LIBS ${GTEST_LIBRARIES} PATH)
+  set(GTEST_LIB_NAME gtest)
+  set(GTEST_STATIC_LIB ${GTEST_LIBS}/${CMAKE_STATIC_LIBRARY_PREFIX}${GTEST_LIB_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX})
+  set(GTEST_MAIN_STATIC_LIB ${GTEST_LIBS}/${CMAKE_STATIC_LIBRARY_PREFIX}${GTEST_LIB_NAME}_main${CMAKE_STATIC_LIBRARY_SUFFIX})
+  set(GTEST_SHARED_LIB ${GTEST_LIBS}/${CMAKE_SHARED_LIBRARY_PREFIX}${GTEST_LIB_NAME}${CMAKE_SHARED_LIBRARY_SUFFIX})
+else ()
+  set(GTEST_FOUND FALSE)
+endif ()
+
+if (GTEST_FOUND)
+  if (NOT GTest_FIND_QUIETLY)
+    message(STATUS "Found the GTest library: ${GTEST_LIBRARIES}")
+  endif ()
+else ()
+  if (NOT GTest_FIND_QUIETLY)
+    set(GTEST_ERR_MSG "Could not find the GTest library. Looked in")
+    if (_gtest_roots)
+      set(GTEST_ERR_MSG "${GTEST_ERR_MSG} ${_gtest_roots}.")
+    else ()
+      set(GTEST_ERR_MSG "${GTEST_ERR_MSG} system search paths.")
+    endif ()
+    if (GTest_FIND_REQUIRED)
+      message(FATAL_ERROR "${GTEST_ERR_MSG}")
+    else ()
+      message(STATUS "${GTEST_ERR_MSG}")
+    endif ()
+  endif ()
+endif ()
+
+mark_as_advanced(
+  GTEST_INCLUDE_DIR
+  GTEST_LIBS
+  GTEST_LIBRARIES
+  GTEST_STATIC_LIB
+  GTEST_SHARED_LIB
+)
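// A hypothetical test case (not part of the patch) in the shape this module
// is built to support: sources matched by tests/cpp/*.cc use Google Test and
// link against ${GTEST_LIBRARIES}. The include path and test name are
// assumptions for illustration.
#include <cmath>
#include <gtest/gtest.h>
#include "../../src/common/avx_helpers.h"

TEST(AvxHelpers, SigmoidMatchesScalar) {
  float in[8] = {-4.0f, -2.0f, -1.0f, 0.0f, 1.0f, 2.0f, 4.0f, 8.0f};
  xgboost::bst_gpair out[8];
  avx::Float8 x(in);
  avx::StoreGpair(out, avx::Sigmoid(x), avx::Float8(1.0f));
  for (int i = 0; i < 8; i++) {
    float expected = 1.0f / (1.0f + std::exp(-in[i]));
    // Loose tolerance: the AVX path uses approximate exp and reciprocal
    EXPECT_NEAR(expected, out[i].grad, 1e-3f);
  }
}

int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}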
diff --git a/src/common/avx_helpers.h b/src/common/avx_helpers.h
new file mode 100644
index 000000000..034a90e26
--- /dev/null
+++ b/src/common/avx_helpers.h
@@ -0,0 +1,287 @@
+/*!
+ * Copyright 2017 by Contributors
+ * \author Rory Mitchell
+ */
+#pragma once
+#include <immintrin.h>
+#include <algorithm>
+#include <cmath>
+#include <cstdio>
+#include "xgboost/base.h"
+
+#ifdef XGBOOST_USE_AVX
+namespace avx {
+/**
+ * \struct Float8
+ *
+ * \brief Helper class for processing a vector of eight floats using AVX
+ * instructions. Implements basic math operators.
+ */
+struct Float8 {
+  __m256 x;
+  explicit Float8(const __m256& x) : x(x) {}
+  explicit Float8(const float& val) : x(_mm256_broadcast_ss(&val)) {}
+  explicit Float8(const float* vec) : x(_mm256_loadu_ps(vec)) {}
+  Float8() : x() {}
+  Float8& operator+=(const Float8& rhs) {
+    x = _mm256_add_ps(x, rhs.x);
+    return *this;
+  }
+  Float8& operator-=(const Float8& rhs) {
+    x = _mm256_sub_ps(x, rhs.x);
+    return *this;
+  }
+  Float8& operator*=(const Float8& rhs) {
+    x = _mm256_mul_ps(x, rhs.x);
+    return *this;
+  }
+  Float8& operator/=(const Float8& rhs) {
+    x = _mm256_div_ps(x, rhs.x);
+    return *this;
+  }
+  void Print() {
+    float* f = reinterpret_cast<float*>(&x);
+    printf("%f %f %f %f %f %f %f %f\n", f[0], f[1], f[2], f[3], f[4], f[5],
+           f[6], f[7]);
+  }
+};
+
+inline Float8 operator+(Float8 lhs, const Float8& rhs) {
+  lhs += rhs;
+  return lhs;
+}
+inline Float8 operator-(Float8 lhs, const Float8& rhs) {
+  lhs -= rhs;
+  return lhs;
+}
+inline Float8 operator*(Float8 lhs, const Float8& rhs) {
+  lhs *= rhs;
+  return lhs;
+}
+inline Float8 operator/(Float8 lhs, const Float8& rhs) {
+  lhs /= rhs;
+  return lhs;
+}
+
+inline Float8 round(const Float8& x) {
+  return Float8(_mm256_round_ps(x.x, _MM_FROUND_TO_NEAREST_INT));
+}
+}  // namespace avx
+
+// Overload std::max/min for Float8
+namespace std {
+inline avx::Float8 max(const avx::Float8& a, const avx::Float8& b) {
+  return avx::Float8(_mm256_max_ps(a.x, b.x));
+}
+inline avx::Float8 min(const avx::Float8& a, const avx::Float8& b) {
+  return avx::Float8(_mm256_min_ps(a.x, b.x));
+}
+}  // namespace std
+
+namespace avx {
+// Approximates e^x as (1 + x/4096)^4096 via twelve squarings (4096 = 2^12)
+// https://codingforspeed.com/using-faster-exponential-approximation/
+inline Float8 Exp4096(Float8 x) {
+  x *= Float8(1.0f / 4096.0f);
+  x += Float8(1.0f);
+  x *= x;
+  x *= x;
+  x *= x;
+  x *= x;
+  x *= x;
+  x *= x;
+  x *= x;
+  x *= x;
+  x *= x;
+  x *= x;
+  x *= x;
+  x *= x;
+  return x;
+}
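// A standalone scalar sketch (not part of the patch) of the identity that
// Exp4096 above vectorizes: e^x ~= (1 + x/n)^n with n = 4096, evaluated by
// repeated squaring.
#include <cmath>
#include <cstdio>

float Exp4096Scalar(float x) {
  x = 1.0f + x / 4096.0f;
  for (int i = 0; i < 12; ++i) x *= x;  // (...)^(2^12) == (...)^4096
  return x;
}

int main() {
  // The approximation is close near zero and degrades for large |x|.
  printf("approx %f vs exact %f\n", Exp4096Scalar(1.0f), std::exp(1.0f));
}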
+inline Float8 pow2n(Float8 const& n) {
+  const float pow2_23 = 8388608.0f;  // 2^23
+  const float bias = 127.0f;         // bias in exponent
+  // Adding bias + 2^23 places the integer n + bias in the float's least
+  // significant bits
+  Float8 a = n + Float8(bias + pow2_23);
+  __m256i b = _mm256_castps_si256(a.x);
+
+  // Do bit shift in SSE so we don't have to use AVX2 instructions
+  __m128i c1 = _mm256_castsi256_si128(b);
+  b = _mm256_permute2f128_si256(b, b, 1);
+  __m128i c2 = _mm256_castsi256_si128(b);
+  c1 = _mm_slli_epi32(c1, 23);
+  c2 = _mm_slli_epi32(c2, 23);
+
+  __m256i c = _mm256_insertf128_si256(_mm256_castsi128_si256(c1), c2, 0x1);
+  return Float8(_mm256_castsi256_ps(c));
+}
+
+inline Float8 polynomial_5(Float8 const& x, const float c0, const float c1,
+                           const float c2, const float c3, const float c4,
+                           const float c5) {
+  // Calculates the polynomial c5*x^5 + c4*x^4 + c3*x^3 + c2*x^2 + c1*x + c0
+  Float8 x2 = x * x;
+  Float8 x4 = x2 * x2;
+  return (Float8(c2) + Float8(c3) * x) * x2 +
+         ((Float8(c4) + Float8(c5) * x) * x4 + (Float8(c0) + Float8(c1) * x));
+}
+
+// AVX exp function based off Agner Fog's vector class library
+// https://github.com/darealshinji/vectorclass/blob/master/vectormath_exp.h
+// Modified so it doesn't require AVX2 instructions
+// Clamps input values to the range [-87.3f, +87.3f]
+inline Float8 ExpAgner(Float8 x) {
+  // Clamp input values
+  const float max_x = 87.3f;
+  x = std::min(x, Float8(max_x));
+  x = std::max(x, Float8(-max_x));
+
+  // 1/log(2)
+  const float log2e = 1.44269504088896340736f;
+
+  // Taylor coefficients
+  const float P0expf = 1.f / 2.f;
+  const float P1expf = 1.f / 6.f;
+  const float P2expf = 1.f / 24.f;
+  const float P3expf = 1.f / 120.f;
+  const float P4expf = 1.f / 720.f;
+  const float P5expf = 1.f / 5040.f;
+
+  // log(2) split into high and low parts for accurate range reduction
+  const float ln2f_hi = 0.693359375f;
+  const float ln2f_lo = -2.12194440e-4f;
+
+  Float8 r = round(x * Float8(log2e));
+  x -= r * Float8(ln2f_hi);
+  x -= r * Float8(ln2f_lo);
+
+  Float8 x2 = x * x;
+  Float8 z = polynomial_5(x, P0expf, P1expf, P2expf, P3expf, P4expf, P5expf);
+  z *= x2;
+  z += x;
+
+  // Multiply by the power of 2 from range reduction
+  Float8 n2 = pow2n(r);
+
+  z = (z + Float8(1.0f)) * n2;
+  return z;
+}
+
+inline Float8 Sigmoid(Float8 x) {
+  // 1 / (1 + e^-x), using the approximate reciprocal _mm256_rcp_ps
+  Float8 exp = ExpAgner(x * Float8(-1.0f));
+  x = Float8(1.0f) + exp;
+  return Float8(_mm256_rcp_ps(x.x));
+}
+
+// Store 8 gradient pairs given vectors containing gradient and Hessian
+inline void StoreGpair(xgboost::bst_gpair* dst, const Float8& grad,
+                       const Float8& hess) {
+  float* ptr = reinterpret_cast<float*>(dst);
+  __m256 gpair_low = _mm256_unpacklo_ps(grad.x, hess.x);
+  __m256 gpair_high = _mm256_unpackhi_ps(grad.x, hess.x);
+  _mm256_storeu_ps(ptr, _mm256_permute2f128_ps(gpair_low, gpair_high, 0x20));
+  _mm256_storeu_ps(ptr + 8,
+                   _mm256_permute2f128_ps(gpair_low, gpair_high, 0x31));
+}
+}  // namespace avx
+#else
+namespace avx {
+/**
+ * \struct Float8
+ *
+ * \brief Fallback implementation not using AVX.
+ */
+struct Float8 {
+  float x[8];
+  explicit Float8(const float& val) {
+    for (int i = 0; i < 8; i++) {
+      x[i] = val;
+    }
+  }
+  explicit Float8(const float* vec) {
+    for (int i = 0; i < 8; i++) {
+      x[i] = vec[i];
+    }
+  }
+  Float8() {}
+  Float8& operator+=(const Float8& rhs) {
+    for (int i = 0; i < 8; i++) {
+      x[i] += rhs.x[i];
+    }
+    return *this;
+  }
+  Float8& operator-=(const Float8& rhs) {
+    for (int i = 0; i < 8; i++) {
+      x[i] -= rhs.x[i];
+    }
+    return *this;
+  }
+  Float8& operator*=(const Float8& rhs) {
+    for (int i = 0; i < 8; i++) {
+      x[i] *= rhs.x[i];
+    }
+    return *this;
+  }
+  Float8& operator/=(const Float8& rhs) {
+    for (int i = 0; i < 8; i++) {
+      x[i] /= rhs.x[i];
+    }
+    return *this;
+  }
+  void Print() {
+    printf("%f %f %f %f %f %f %f %f\n", x[0], x[1], x[2], x[3], x[4], x[5],
+           x[6], x[7]);
+  }
+};
+
+inline Float8 operator+(Float8 lhs, const Float8& rhs) {
+  lhs += rhs;
+  return lhs;
+}
+inline Float8 operator-(Float8 lhs, const Float8& rhs) {
+  lhs -= rhs;
+  return lhs;
+}
+inline Float8 operator*(Float8 lhs, const Float8& rhs) {
+  lhs *= rhs;
+  return lhs;
+}
+inline Float8 operator/(Float8 lhs, const Float8& rhs) {
+  lhs /= rhs;
+  return lhs;
+}
+
+// Store 8 gradient pairs given vectors containing gradient and Hessian
+inline void StoreGpair(xgboost::bst_gpair* dst, const Float8& grad,
+                       const Float8& hess) {
+  for (int i = 0; i < 8; i++) {
+    dst[i] = xgboost::bst_gpair(grad.x[i], hess.x[i]);
+  }
+}
+
+inline Float8 Sigmoid(Float8 x) {
+  Float8 sig;
+  for (int i = 0; i < 8; i++) {
+    sig.x[i] = 1.0f / (1.0f + std::exp(-x.x[i]));
+  }
+  return sig;
+}
+}  // namespace avx
+
+namespace std {
+inline avx::Float8 max(const avx::Float8& a, const avx::Float8& b) {
+  avx::Float8 max;
+  for (int i = 0; i < 8; i++) {
+    max.x[i] = std::max(a.x[i], b.x[i]);
+  }
+  return max;
+}
+inline avx::Float8 min(const avx::Float8& a, const avx::Float8& b) {
+  avx::Float8 min;
+  for (int i = 0; i < 8; i++) {
+    min.x[i] = std::min(a.x[i], b.x[i]);
+  }
+  return min;
+}
+}  // namespace std
+#endif
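// A small usage sketch (not part of the patch): the same source compiles
// whether or not XGBOOST_USE_AVX is defined, which is the point of the
// fallback above. The include path is an assumption (relative to repo root).
#include "src/common/avx_helpers.h"

int main() {
  float preds[8] = {-2.f, -1.f, -0.5f, 0.f, 0.5f, 1.f, 2.f, 4.f};
  avx::Float8 p(preds);
  avx::Float8 s = avx::Sigmoid(p);  // approximate under AVX, exact in fallback
  s.Print();
  return 0;
}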
diff --git a/src/common/math.h b/src/common/math.h
index 518da23bb..6e594032e 100644
--- a/src/common/math.h
+++ b/src/common/math.h
@@ -11,6 +11,7 @@
 #include <cmath>
 #include <algorithm>
 #include <utility>
+#include "avx_helpers.h"
 
 namespace xgboost {
 namespace common {
@@ -23,6 +24,10 @@
 inline float Sigmoid(float x) {
   return 1.0f / (1.0f + std::exp(-x));
 }
 
+inline avx::Float8 Sigmoid(avx::Float8 x) {
+  return avx::Sigmoid(x);
+}
+
 /*!
  * \brief do inplace softmax transformation on p_rec
  * \param p_rec the input/output vector of the values.
diff --git a/src/objective/regression_obj.cc b/src/objective/regression_obj.cc
index 2597d2de4..83db376c8 100644
--- a/src/objective/regression_obj.cc
+++ b/src/objective/regression_obj.cc
@@ -11,6 +11,7 @@
 #include <utility>
 #include <cmath>
 #include "../common/math.h"
+#include "../common/avx_helpers.h"
 
 namespace xgboost {
 namespace obj {
@@ -20,22 +21,28 @@ DMLC_REGISTRY_FILE_TAG(regression_obj);
 // common regressions
 // linear regression
 struct LinearSquareLoss {
-  static bst_float PredTransform(bst_float x) { return x; }
+  template <typename T>
+  static T PredTransform(T x) { return x; }
   static bool CheckLabel(bst_float x) { return true; }
-  static bst_float FirstOrderGradient(bst_float predt, bst_float label) { return predt - label; }
-  static bst_float SecondOrderGradient(bst_float predt, bst_float label) { return 1.0f; }
+  template <typename T>
+  static T FirstOrderGradient(T predt, T label) { return predt - label; }
+  template <typename T>
+  static T SecondOrderGradient(T predt, T label) { return T(1.0f); }
   static bst_float ProbToMargin(bst_float base_score) { return base_score; }
   static const char* LabelErrorMsg() { return ""; }
   static const char* DefaultEvalMetric() { return "rmse"; }
 };
 // logistic loss for probability regression task
 struct LogisticRegression {
-  static bst_float PredTransform(bst_float x) { return common::Sigmoid(x); }
+  template <typename T>
+  static T PredTransform(T x) { return common::Sigmoid(x); }
   static bool CheckLabel(bst_float x) { return x >= 0.0f && x <= 1.0f; }
-  static bst_float FirstOrderGradient(bst_float predt, bst_float label) { return predt - label; }
-  static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
-    const float eps = 1e-16f;
-    return std::max(predt * (1.0f - predt), eps);
+  template <typename T>
+  static T FirstOrderGradient(T predt, T label) { return predt - label; }
+  template <typename T>
+  static T SecondOrderGradient(T predt, T label) {
+    const T eps = T(1e-16f);
+    return std::max(predt * (T(1.0f) - predt), eps);
   }
   static bst_float ProbToMargin(bst_float base_score) {
     CHECK(base_score > 0.0f && base_score < 1.0f)
@@ -53,28 +60,28 @@
 struct LogisticClassification : public LogisticRegression {
 };
 // logistic loss, but predict un-transformed margin
 struct LogisticRaw : public LogisticRegression {
-  static bst_float PredTransform(bst_float x) { return x; }
-  static bst_float FirstOrderGradient(bst_float predt, bst_float label) {
+  template <typename T>
+  static T PredTransform(T x) { return x; }
+  template <typename T>
+  static T FirstOrderGradient(T predt, T label) {
     predt = common::Sigmoid(predt);
     return predt - label;
   }
-  static bst_float SecondOrderGradient(bst_float predt, bst_float label) {
-    const float eps = 1e-16f;
+  template <typename T>
+  static T SecondOrderGradient(T predt, T label) {
+    const T eps = T(1e-16f);
     predt = common::Sigmoid(predt);
-    return std::max(predt * (1.0f - predt), eps);
+    return std::max(predt * (T(1.0f) - predt), eps);
   }
   static const char* DefaultEvalMetric() { return "auc"; }
 };
 
 struct RegLossParam : public dmlc::Parameter<RegLossParam> {
   float scale_pos_weight;
-  int nthread;
   // declare parameters
   DMLC_DECLARE_PARAMETER(RegLossParam) {
     DMLC_DECLARE_FIELD(scale_pos_weight).set_default(1.0f).set_lower_bound(0.0f)
       .describe("Scale the weight of positive examples by this factor");
-    DMLC_DECLARE_FIELD(nthread).set_default(0).describe(
-        "Number of threads to use.");
   }
 };
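// A minimal sketch (not part of the patch) of why the loss methods above were
// templated: one definition now instantiates both for scalar bst_float in the
// remainder loop and for avx::Float8 in the vectorized loop. ToyLoss is a
// stand-in, not the real LogisticRegression.
#include <cstdio>

struct ToyLoss {
  template <typename T>
  static T FirstOrderGradient(T predt, T label) { return predt - label; }
};

int main() {
  // Scalar instantiation; an avx::Float8 instantiation compiles identically.
  printf("%f\n", ToyLoss::FirstOrderGradient(0.7f, 1.0f));
}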
@@ -98,23 +105,36 @@ class RegLossObj : public ObjFunction {
     this->LazyCheckLabels(info.labels);
     out_gpair->resize(preds.size());
+    const omp_ulong n = static_cast<omp_ulong>(preds.size());
+    auto gpair_ptr = out_gpair->data();
+    avx::Float8 scale(param_.scale_pos_weight);
 
-    // start calculating gradient
-    const omp_ulong ndata = static_cast<omp_ulong>(preds.size());
-    int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread;
-#pragma omp parallel for schedule(static) num_threads(nthread)
-    for (omp_ulong i = 0; i < ndata; ++i) {
+    const omp_ulong remainder = n % 8;
+    int nthread = omp_get_max_threads();
+    // Use a maximum of 8 threads
+#pragma omp parallel for schedule(static) num_threads(std::min(8, nthread))
+    for (omp_ulong i = 0; i < n - remainder; i += 8) {
+      avx::Float8 y(&info.labels[i]);
+      avx::Float8 p = Loss::PredTransform(avx::Float8(&preds[i]));
+      avx::Float8 w = info.weights.empty() ? avx::Float8(1.0f)
+                                           : avx::Float8(&info.weights[i]);
+      // Branchless weight adjustment for positive labels
+      w += y * (scale * w - w);
+      avx::Float8 grad = Loss::FirstOrderGradient(p, y);
+      avx::Float8 hess = Loss::SecondOrderGradient(p, y);
+      avx::StoreGpair(gpair_ptr + i, grad * w, hess * w);
+    }
+    // Scalar loop for the tail, as n is not generally a multiple of 8
+    for (omp_ulong i = n - remainder; i < n; ++i) {
       auto y = info.labels[i];
       bst_float p = Loss::PredTransform(preds[i]);
       bst_float w = info.GetWeight(i);
-      // Branchless version of the below function
-      // The branch is particularly slow as the cpu cannot predict the label
-      // with any accuracy resulting in frequent pipeline stalls
-      // if (info.labels[i] == 1.0f) w *= param_.scale_pos_weight;
       w += y * ((param_.scale_pos_weight * w) - w);
       (*out_gpair)[i] = bst_gpair(Loss::FirstOrderGradient(p, y) * w,
                                   Loss::SecondOrderGradient(p, y) * w);
     }
+
+    // Reset omp max threads
+    omp_set_num_threads(nthread);
   }
   const char *DefaultEvalMetric() const override {
     return Loss::DefaultEvalMetric();
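// A standalone scalar sketch (not part of the patch) of the loop structure
// GetGradient uses above: an 8-wide main loop followed by a scalar tail,
// since the input size is not generally a multiple of 8.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> v(19, 1.0f);
  const std::size_t n = v.size();
  const std::size_t remainder = n % 8;
  for (std::size_t i = 0; i < n - remainder; i += 8) {
    for (int k = 0; k < 8; ++k) v[i + k] *= 2.0f;  // stands in for Float8 math
  }
  for (std::size_t i = n - remainder; i < n; ++i) {
    v[i] *= 2.0f;  // scalar tail
  }
  printf("head %f tail %f\n", v[0], v[n - 1]);
}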
@@ -165,13 +185,10 @@ XGBOOST_REGISTER_OBJECTIVE(LogisticRaw, "binary:logitraw")
 
 // declare parameter
 struct PoissonRegressionParam : public dmlc::Parameter<PoissonRegressionParam> {
   float max_delta_step;
-  int nthread;
   DMLC_DECLARE_PARAMETER(PoissonRegressionParam) {
     DMLC_DECLARE_FIELD(max_delta_step).set_lower_bound(0.0f).set_default(0.7f)
       .describe("Maximum delta step we allow each weight estimation to be." \
                 " This parameter is required for Poisson regression.");
-    DMLC_DECLARE_FIELD(nthread).set_default(0).describe(
-        "Number of threads to use.");
   }
 };
@@ -194,8 +211,7 @@ class PoissonRegression : public ObjFunction {
     bool label_correct = true;
     // start calculating gradient
     const omp_ulong ndata = static_cast<omp_ulong>(preds.size()); // NOLINT(*)
-    int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread;
-    #pragma omp parallel for schedule(static) num_threads(nthread)
+#pragma omp parallel for schedule(static)
     for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
       bst_float p = preds[i];
       bst_float w = info.GetWeight(i);
@@ -212,8 +228,7 @@ class PoissonRegression : public ObjFunction {
   void PredTransform(std::vector<bst_float> *io_preds) override {
     std::vector<bst_float> &preds = *io_preds;
     const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
-    int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread;
-    #pragma omp parallel for schedule(static) num_threads(nthread)
+#pragma omp parallel for schedule(static)
     for (long j = 0; j < ndata; ++j) { // NOLINT(*)
       preds[j] = std::exp(preds[j]);
     }
@@ -297,12 +312,9 @@ XGBOOST_REGISTER_OBJECTIVE(GammaRegression, "reg:gamma")
 
 // declare parameter
 struct TweedieRegressionParam : public dmlc::Parameter<TweedieRegressionParam> {
   float tweedie_variance_power;
-  int nthread;
   DMLC_DECLARE_PARAMETER(TweedieRegressionParam) {
     DMLC_DECLARE_FIELD(tweedie_variance_power).set_range(1.0f, 2.0f).set_default(1.5f)
       .describe("Tweedie variance power. Must be in the range [1, 2).");
-    DMLC_DECLARE_FIELD(nthread).set_default(0).describe(
-        "Number of threads to use.");
   }
 };
@@ -325,8 +337,7 @@ class TweedieRegression : public ObjFunction {
     bool label_correct = true;
     // start calculating gradient
     const omp_ulong ndata = static_cast<omp_ulong>(preds.size()); // NOLINT(*)
-    int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread;
-    #pragma omp parallel for schedule(static) num_threads(nthread)
+#pragma omp parallel for schedule(static)
     for (omp_ulong i = 0; i < ndata; ++i) { // NOLINT(*)
       bst_float p = preds[i];
       bst_float w = info.GetWeight(i);
@@ -346,8 +357,7 @@ class TweedieRegression : public ObjFunction {
   void PredTransform(std::vector<bst_float> *io_preds) override {
     std::vector<bst_float> &preds = *io_preds;
     const long ndata = static_cast<long>(preds.size()); // NOLINT(*)
-    int nthread = param_.nthread == 0 ? omp_get_num_procs() : param_.nthread;
-    #pragma omp parallel for schedule(static) num_threads(nthread)
+#pragma omp parallel for schedule(static)
     for (long j = 0; j < ndata; ++j) { // NOLINT(*)
       preds[j] = std::exp(preds[j]);
     }
diff --git a/src/tree/updater_gpu_hist_experimental.cu b/src/tree/updater_gpu_hist_experimental.cu
index d351fa183..67e71a3f1 100644
--- a/src/tree/updater_gpu_hist_experimental.cu
+++ b/src/tree/updater_gpu_hist_experimental.cu
@@ -1,8 +1,8 @@
 /*!
  * Copyright 2017 XGBoost contributors
  */
-#include <xgboost/tree_updater.h>
 #include <utility>
+#include <xgboost/tree_updater.h>
 #include <vector>
 #include <memory>
 #include <string>
@@ -564,7 +564,6 @@ class GPUHistMakerExperimental : public TreeUpdater {
                 const RegTree& tree) {
     monitor.Start("InitDataOnce");
     if (!initialised) {
-      CheckGradientMax(gpair);
       this->InitDataOnce(dmat);
     }
     monitor.Stop("InitDataOnce");
@@ -774,6 +773,9 @@ class GPUHistMakerExperimental : public TreeUpdater {
   void UpdateTree(const std::vector<bst_gpair>& gpair, DMatrix* p_fmat,
                   RegTree* p_tree) {
+    // Temporarily store the number of threads so we can change it back later
+    int nthread = omp_get_max_threads();
+
     auto& tree = *p_tree;
 
     monitor.Start("InitData");
@@ -819,6 +821,9 @@ class GPUHistMakerExperimental : public TreeUpdater {
         monitor.Stop("EvaluateSplits");
       }
     }
+
+    // Reset omp num threads
+    omp_set_num_threads(nthread);
   }
 
   struct ExpandEntry {
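// A standalone sketch (not part of the patch) of the save/restore pattern
// added to UpdateTree above: code called in between may change the global
// OpenMP thread limit, so the previous maximum is restored on the way out.
// DoWorkThatMayChangeThreadCount is a hypothetical stand-in.
#include <omp.h>
#include <cstdio>

void DoWorkThatMayChangeThreadCount() {
  omp_set_num_threads(2);  // stand-in for code that lowers the limit
}

int main() {
  int nthread = omp_get_max_threads();  // save
  DoWorkThatMayChangeThreadCount();
  omp_set_num_threads(nthread);  // restore
  printf("restored max threads: %d\n", omp_get_max_threads());
}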
diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh
index f1fea3273..92e0ccb8f 100755
--- a/tests/travis/run_test.sh
+++ b/tests/travis/run_test.sh
@@ -91,10 +91,30 @@ if [ ${TASK} == "java_test" ]; then
 fi
 
 if [ ${TASK} == "cmake_test" ]; then
-  mkdir build
-  cd build
-  cmake ..
+  set -e
+
+  # Build gtest via cmake
+  wget https://github.com/google/googletest/archive/release-1.7.0.zip
+  unzip release-1.7.0.zip
+  mv googletest-release-1.7.0 gtest && cd gtest
+  cmake . && make
+  mkdir lib && mv libgtest.a lib
+  cd ..
+  rm -rf release-1.7.0.zip
+
+  # Build/test without AVX
+  mkdir build && cd build
+  cmake .. -DGOOGLE_TEST=ON
   make
+  cd ..
+  ./testxgboost
+  rm -rf build
+
+  # Build/test with AVX
+  mkdir build && cd build
+  cmake .. -DGOOGLE_TEST=ON -DUSE_AVX=ON
+  make
+  cd ..
+  ./testxgboost
 fi
 
 if [ ${TASK} == "cpp_test" ]; then
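// A closing sketch (not part of the patch) of why the script runs the test
// suite both with and without USE_AVX: under AVX, Sigmoid goes through the
// approximate ExpAgner and _mm256_rcp_ps, so outputs track the exact formula
// closely but are not bit-identical, matching the caveat in the USE_AVX
// option description.
#include <cmath>
#include <cstdio>

float SigmoidExact(float x) { return 1.0f / (1.0f + std::exp(-x)); }

int main() {
  // Under AVX the relative error is dominated by _mm256_rcp_ps (on the order
  // of 1e-4), which is why tests compare against a tolerance rather than
  // expecting exact equality.
  for (float x = -4.0f; x <= 4.0f; x += 2.0f) {
    printf("sigmoid(%+.1f) = %f\n", x, SigmoidExact(x));
  }
}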