// xgboost/plugin/updater_gpu/test/cpp/test_device_helpers.cu
// (source viewer metadata: 29 lines, 974 B, Plaintext)

/*!
* Copyright 2017 XGBoost contributors
*/
#include <thrust/device_vector.h>
#include <xgboost/base.h>
#include "../../src/device_helpers.cuh"
#include "gtest/gtest.h"
// Fixture for the load-balancing-search (LBS) test.
//
// gidx: ten input elements. The test only uses its length, not its values.
static const std::vector<int> gidx{0, 2, 5, 1, 3, 6, 0, 2, 0, 7};
// row_ptr: segment offsets into the elements. There are four segments, so
// segment s spans [row_ptr[s], row_ptr[s + 1]).
static const std::vector<int> row_ptr{0, 3, 6, 8, 10};
// lbs_seg_output: expected segment index for every element position.
static const std::vector<int> lbs_seg_output{0, 0, 0, 1, 1, 1, 2, 2, 3, 3};
// Runs dh::TransformLbs over the file-level fixture. Returns, for each element
// index, the segment index that the callback received for it.
//
// The element count is gidx.size(); only the size of gidx is used, not its
// values. The segment offsets come from row_ptr, which gives
// row_ptr.size() - 1 segments.
// NOTE(review): the leading 0 passed to TransformLbs is presumably the device
// index. Confirm against dh::TransformLbs.
thrust::device_vector<int> test_lbs() {
  thrust::device_vector<int> device_row_ptr = row_ptr;
  thrust::device_vector<int> device_output_row(gidx.size(), 0);
  // The callback executes on the device, so capture a plain device pointer
  // instead of a thrust::device_ptr wrapper.
  int *d_output_row = device_output_row.data().get();
  dh::CubMemory temp_memory;
  dh::TransformLbs(
      0, &temp_memory, gidx.size(), device_row_ptr.data(), row_ptr.size() - 1,
      [=] __device__(int idx, int ridx) { d_output_row[idx] = ridx; });
  // Report launch-configuration errors first. Then wait for completion so
  // that asynchronous execution errors surface before the result is used.
  dh::safe_cuda(cudaGetLastError());
  dh::safe_cuda(cudaDeviceSynchronize());
  return device_output_row;
}
// Checks that the segment ids produced by test_lbs() exactly equal the
// expected per-element segment ids in lbs_seg_output.
TEST(lbs, Test) {
  const thrust::device_vector<int> segment_ids = test_lbs();
  ASSERT_TRUE(segment_ids == lbs_seg_output);
}