Support dmatrix construction from cupy array (#5206)
This commit is contained in:
@@ -6,7 +6,8 @@
|
||||
#include <thrust/sequence.h>
|
||||
#include "../../../src/data/device_adapter.cuh"
|
||||
#include "../helpers.h"
|
||||
#include "test_columnar.h"
|
||||
#include "test_array_interface.h"
|
||||
#include "../../../src/data/array_interface.h"
|
||||
|
||||
using namespace xgboost; // NOLINT
|
||||
|
||||
@@ -316,3 +317,55 @@ TEST(SimpleDMatrix, FromColumnarSparseBasic) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST(SimpleDMatrix, FromCupy){
|
||||
int rows = 50;
|
||||
int cols = 10;
|
||||
thrust::device_vector< float> data(rows*cols);
|
||||
auto json_array_interface = Generate2dArrayInterface(rows, cols, "<f4", &data);
|
||||
std::stringstream ss;
|
||||
Json::Dump(json_array_interface, &ss);
|
||||
std::string str = ss.str();
|
||||
data::CupyAdapter adapter(str);
|
||||
data::SimpleDMatrix dmat(&adapter, -1, 1);
|
||||
EXPECT_EQ(dmat.Info().num_col_, cols);
|
||||
EXPECT_EQ(dmat.Info().num_row_, rows);
|
||||
EXPECT_EQ(dmat.Info().num_nonzero_, rows*cols);
|
||||
|
||||
for (auto& batch : dmat.GetBatches<SparsePage>()) {
|
||||
for (auto i = 0ull; i < batch.Size(); i++) {
|
||||
auto inst = batch[i];
|
||||
for (auto j = 0ull; j < inst.size(); j++) {
|
||||
EXPECT_EQ(inst[j].fvalue, i * cols + j);
|
||||
EXPECT_EQ(inst[j].index, j);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(SimpleDMatrix, FromCupySparse){
|
||||
int rows = 2;
|
||||
int cols = 2;
|
||||
thrust::device_vector< float> data(rows*cols);
|
||||
auto json_array_interface = Generate2dArrayInterface(rows, cols, "<f4", &data);
|
||||
data[1] = std::numeric_limits<float>::quiet_NaN();
|
||||
data[2] = std::numeric_limits<float>::quiet_NaN();
|
||||
std::stringstream ss;
|
||||
Json::Dump(json_array_interface, &ss);
|
||||
std::string str = ss.str();
|
||||
data::CupyAdapter adapter(str);
|
||||
data::SimpleDMatrix dmat(&adapter, -1, 1);
|
||||
EXPECT_EQ(dmat.Info().num_col_, cols);
|
||||
EXPECT_EQ(dmat.Info().num_row_, rows);
|
||||
EXPECT_EQ(dmat.Info().num_nonzero_, rows * cols - 2);
|
||||
auto& batch = *dmat.GetBatches<SparsePage>().begin();
|
||||
auto inst0 = batch[0];
|
||||
auto inst1 = batch[1];
|
||||
EXPECT_EQ(batch[0].size(), 1);
|
||||
EXPECT_EQ(batch[1].size(), 1);
|
||||
EXPECT_EQ(batch[0][0].fvalue, 0.0f);
|
||||
EXPECT_EQ(batch[0][0].index, 0);
|
||||
EXPECT_EQ(batch[1][0].fvalue, 3.0f);
|
||||
EXPECT_EQ(batch[1][0].index, 1);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user