Cudf support. (#4745)
* Initial support for cudf integration. * Add two C APIs for consuming data and metainfo. * Add CopyFrom for SimpleCSRSource as a generic function to consume the data. * Add FromDeviceColumnar for consuming device data. * Add new MetaInfo::SetInfo for consuming label, weight etc.
This commit is contained in:
committed by
Rory Mitchell
parent
ab357dd41c
commit
9700776597
@@ -16,7 +16,7 @@ namespace xgboost {
|
||||
namespace {
|
||||
|
||||
struct FConstraintWrapper : public FeatureInteractionConstraint {
|
||||
common::Span<BitField> GetNodeConstraints() {
|
||||
common::Span<LBitField64> GetNodeConstraints() {
|
||||
return FeatureInteractionConstraint::s_node_constraints_;
|
||||
}
|
||||
FConstraintWrapper(tree::TrainParam param, int32_t n_features) :
|
||||
@@ -44,13 +44,13 @@ tree::TrainParam GetParameter() {
|
||||
return param;
|
||||
}
|
||||
|
||||
void CompareBitField(BitField d_field, std::set<uint32_t> positions) {
|
||||
std::vector<BitField::value_type> h_field_storage(d_field.bits_.size());
|
||||
thrust::copy(thrust::device_ptr<BitField::value_type>(d_field.bits_.data()),
|
||||
thrust::device_ptr<BitField::value_type>(
|
||||
void CompareBitField(LBitField64 d_field, std::set<uint32_t> positions) {
|
||||
std::vector<LBitField64::value_type> h_field_storage(d_field.bits_.size());
|
||||
thrust::copy(thrust::device_ptr<LBitField64::value_type>(d_field.bits_.data()),
|
||||
thrust::device_ptr<LBitField64::value_type>(
|
||||
d_field.bits_.data() + d_field.bits_.size()),
|
||||
h_field_storage.data());
|
||||
BitField h_field;
|
||||
LBitField64 h_field;
|
||||
h_field.bits_ = {h_field_storage.data(), h_field_storage.data() + h_field_storage.size()};
|
||||
|
||||
for (size_t i = 0; i < h_field.Size(); ++i) {
|
||||
@@ -71,14 +71,14 @@ TEST(FeatureInteractionConstraint, Init) {
|
||||
tree::TrainParam param = GetParameter();
|
||||
FConstraintWrapper constraints(param, kFeatures);
|
||||
ASSERT_EQ(constraints.Features(), kFeatures);
|
||||
common::Span<BitField> s_nodes_constraints = constraints.GetNodeConstraints();
|
||||
for (BitField const& d_node : s_nodes_constraints) {
|
||||
std::vector<BitField::value_type> h_node_storage(d_node.bits_.size());
|
||||
thrust::copy(thrust::device_ptr<BitField::value_type>(d_node.bits_.data()),
|
||||
thrust::device_ptr<BitField::value_type>(
|
||||
common::Span<LBitField64> s_nodes_constraints = constraints.GetNodeConstraints();
|
||||
for (LBitField64 const& d_node : s_nodes_constraints) {
|
||||
std::vector<LBitField64::value_type> h_node_storage(d_node.bits_.size());
|
||||
thrust::copy(thrust::device_ptr<LBitField64::value_type>(d_node.bits_.data()),
|
||||
thrust::device_ptr<LBitField64::value_type>(
|
||||
d_node.bits_.data() + d_node.bits_.size()),
|
||||
h_node_storage.data());
|
||||
BitField h_node;
|
||||
LBitField64 h_node;
|
||||
h_node.bits_ = {h_node_storage.data(), h_node_storage.data() + h_node_storage.size()};
|
||||
// no feature is attached to node.
|
||||
for (size_t i = 0; i < h_node.Size(); ++i) {
|
||||
@@ -108,7 +108,7 @@ TEST(FeatureInteractionConstraint, Init) {
|
||||
}
|
||||
|
||||
{
|
||||
// Test having more than 1 BitField::value_type
|
||||
// Test having more than 1 LBitField64::value_type
|
||||
int32_t constexpr kFeatures = 129;
|
||||
tree::TrainParam param = GetParameter();
|
||||
param.interaction_constraints = R"([[0, 1, 3], [3, 5, 128], [127, 128]])";
|
||||
@@ -129,7 +129,7 @@ TEST(FeatureInteractionConstraint, Split) {
|
||||
FConstraintWrapper constraints(param, kFeatures);
|
||||
|
||||
{
|
||||
BitField d_node[3];
|
||||
LBitField64 d_node[3];
|
||||
constraints.Split(0, /*feature_id=*/1, 1, 2);
|
||||
for (size_t nid = 0; nid < 3; ++nid) {
|
||||
d_node[nid] = constraints.GetNodeConstraints()[nid];
|
||||
@@ -139,7 +139,7 @@ TEST(FeatureInteractionConstraint, Split) {
|
||||
}
|
||||
|
||||
{
|
||||
BitField d_node[5];
|
||||
LBitField64 d_node[5];
|
||||
constraints.Split(1, /*feature_id=*/0, /*left_id=*/3, /*right_id=*/4);
|
||||
for (auto nid : {1, 3, 4}) {
|
||||
d_node[nid] = constraints.GetNodeConstraints()[nid];
|
||||
|
||||
Reference in New Issue
Block a user