Define the new device parameter. (#9362)
This commit is contained in:
@@ -3,16 +3,17 @@
|
||||
*/
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cstddef> // std::size_t
|
||||
#include <utility> // std::pair
|
||||
#include <vector> // std::vector
|
||||
#include <cstddef> // std::size_t
|
||||
#include <utility> // std::pair
|
||||
#include <vector> // std::vector
|
||||
|
||||
#include "../../../src/common/linalg_op.cuh" // ElementWiseTransformDevice
|
||||
#include "../../../src/common/stats.cuh"
|
||||
#include "xgboost/base.h" // XGBOOST_DEVICE
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||
#include "xgboost/linalg.h" // Tensor
|
||||
#include "../helpers.h"
|
||||
#include "xgboost/base.h" // XGBOOST_DEVICE
|
||||
#include "xgboost/context.h" // Context
|
||||
#include "xgboost/host_device_vector.h" // HostDeviceVector
|
||||
#include "xgboost/linalg.h" // Tensor
|
||||
|
||||
namespace xgboost {
|
||||
namespace common {
|
||||
@@ -33,7 +34,7 @@ class StatsGPU : public ::testing::Test {
|
||||
}
|
||||
|
||||
public:
|
||||
void SetUp() override { ctx_.gpu_id = 0; }
|
||||
void SetUp() override { ctx_ = MakeCUDACtx(0); }
|
||||
|
||||
void WeightedMulti() {
|
||||
// data for one segment
|
||||
|
||||
Reference in New Issue
Block a user