Cudf support. (#4745)

* Initial support for cudf integration.

* Add two C APIs for consuming data and metainfo.

* Add CopyFrom for SimpleCSRSource as a generic function to consume the data.

* Add FromDeviceColumnar for consuming device data.

* Add new MetaInfo::SetInfo for consuming label, weight etc.
This commit is contained in:
Jiaming Yuan 2019-08-19 00:51:40 -04:00 committed by Rory Mitchell
parent ab357dd41c
commit 9700776597
26 changed files with 1385 additions and 287 deletions

View File

@ -10,17 +10,17 @@
#include <dmlc/base.h>
#include <dmlc/data.h>
#include <rabit/rabit.h>
#include <cstring>
#include <xgboost/base.h>
#include <memory>
#include <numeric>
#include <algorithm>
#include <string>
#include <utility>
#include <vector>
#include "./base.h"
#include "../../src/common/span.h"
#include "../../src/common/group_data.h"
#include "../../src/common/host_device_vector.h"
namespace xgboost {
@ -121,6 +121,12 @@ class MetaInfo {
* \param num Number of elements in the source array.
*/
void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num);
/*!
* \brief Set information in the meta info with array interface.
* \param key The key of the information.
* \param interface_str String representation of json format array interface.
*/
void SetInfo(const char* key, std::string const& interface_str);
private:
/*! \brief argsort of labels */

View File

@ -4,8 +4,6 @@
#ifndef XGBOOST_JSON_H_
#define XGBOOST_JSON_H_
#include <dmlc/io.h>
#include <xgboost/logging.h>
#include <string>

View File

@ -1,5 +1,5 @@
/*!
* Copyright (c) 2015 by Contributors
* Copyright (c) 2015-2019 by Contributors
* \file logging.h
* \brief defines console logging options for xgboost.
* Use to enforce unified print behavior.
@ -11,12 +11,13 @@
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <dmlc/thread_local.h>
#include <xgboost/base.h>
#include <sstream>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "./base.h"
namespace xgboost {
@ -154,5 +155,14 @@ using LogCallbackRegistryStore = dmlc::ThreadLocalStore<LogCallbackRegistry>;
::xgboost::ConsoleLogger::LogVerbosity::kIgnore)
// Enable LOG(TRACKER) for print messages to tracker
#define LOG_TRACKER ::xgboost::TrackerLogger()
// Replace an already-defined CHECK macro (presumably the one provided by
// <dmlc/logging.h>, included above -- TODO confirm) with a variant whose
// failure test is wrapped in XGBOOST_EXPECT(..., false).
//
// On failure the macro streams "Check failed: <cond>: " into a
// dmlc::LogMessageFatal; anything the caller appends with `<<` follows
// that prefix.
//
// NOTE(review): the macro expands to a bare `if` statement, so using it as
// the body of an un-braced `if` that has an `else` would bind that `else`
// to the macro's `if`.
#if defined(CHECK)
#undef CHECK
#define CHECK(cond) \
if (XGBOOST_EXPECT(!(cond), false)) \
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \
<< "Check failed: " #cond << ": "
#endif // defined(CHECK)
} // namespace xgboost.
#endif // XGBOOST_LOGGING_H_

View File

@ -124,9 +124,16 @@ except ImportError:
class DataTable(object):
""" dummy for datatable.DataTable """
DT_INSTALLED = False
try:
from cudf import DataFrame as CUDF_DataFrame
CUDF_INSTALLED = True
except ImportError:
CUDF_DataFrame = object
CUDF_INSTALLED = False
# sklearn
try:
from sklearn.base import BaseEstimator

View File

@ -1,26 +1,27 @@
# coding: utf-8
# pylint: disable=too-many-arguments, too-many-branches, invalid-name
# pylint: disable=too-many-branches, too-many-lines, too-many-locals
# pylint: disable=too-many-public-methods
"""Core XGBoost Library."""
from __future__ import absolute_import
import collections
# pylint: disable=no-name-in-module,import-error
try:
from collections.abc import Mapping # Python 3
except ImportError:
from collections import Mapping # Python 2
from collections.abc import Mapping # Python 3
# pylint: enable=no-name-in-module,import-error
import math
import ctypes
import os
import re
import sys
import warnings
import json
import numpy as np
import scipy.sparse
from .compat import (STRING_TYPES, PY3, DataFrame, MultiIndex, py_str,
PANDAS_INSTALLED, DataTable, os_fspath, os_PathLike)
PANDAS_INSTALLED, DataTable,
CUDF_INSTALLED, CUDF_DataFrame,
os_fspath, os_PathLike)
from .libpath import find_lib_path
@ -131,8 +132,10 @@ def _load_lib():
os_error_list = []
for lib_path in lib_paths:
try:
# needed when the lib is linked with non-system-available dependencies
os.environ['PATH'] = os.pathsep.join(pathBackup + [os.path.dirname(lib_path)])
# needed when the lib is linked with non-system-available
# dependencies
os.environ['PATH'] = os.pathsep.join(
pathBackup + [os.path.dirname(lib_path)])
lib = ctypes.cdll.LoadLibrary(lib_path)
lib_success = True
except OSError as e:
@ -217,6 +220,51 @@ def c_array(ctype, values):
return (ctype * len(values))(*values)
def _use_columnar_initializer(data):
    '''Tell whether `data` must go through the columnar initializer.

    The columnar initializer passes the data to the native library as a JSON
    string.  A cudf DataFrame is currently the only input that qualifies, and
    only when cudf could be imported.
    '''
    is_cudf_frame = CUDF_INSTALLED and isinstance(data, CUDF_DataFrame)
    return bool(is_cudf_frame)
def _extract_interface_from_cudf(df, is_info):
    '''Serialise the columns of a cudf DataFrame as JSON array interfaces.

    This function should be upstreamed to cudf.

    Parameters
    ----------
    df : cudf.DataFrame
        Source data.  Anything else is rejected with ``ValueError``.
    is_info : bool
        When true only the interface of the first column is serialised (a
        single JSON object); otherwise a JSON list with one entry per column
        is produced.

    Returns
    -------
    bytes
        UTF-8 encoded JSON document.
    '''
    if not _use_columnar_initializer(df):
        raise ValueError('Only cudf is supported for initializing as json ' +
                         'columnar format. For other libraries please ' +
                         'refer to specific API.')

    def _cuda_interface(buf):
        return buf.mem.__cuda_array_interface__

    array_interfaces = []
    for name in df.columns:
        column = df[name]
        col_interface = _cuda_interface(column.data)
        if column.has_null_mask:
            # The validity mask travels nested inside the column interface,
            # together with the number of nulls it encodes.
            mask_interface = _cuda_interface(column.nullmask)
            mask_interface['null_count'] = column.null_count
            col_interface['mask'] = mask_interface
        array_interfaces.append(col_interface)

    payload = array_interfaces[0] if is_info else array_interfaces
    return bytes(json.dumps(payload, indent=2), 'utf-8')
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
'float16': 'float', 'float32': 'float', 'float64': 'float',
@ -256,15 +304,18 @@ def _maybe_pandas_data(data, feature_names, feature_types):
def _maybe_pandas_label(label):
""" Extract internal data from pd.DataFrame for DMatrix label """
"""Extract internal data from pd.DataFrame for DMatrix label."""
if PANDAS_INSTALLED and isinstance(label, DataFrame):
if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns')
raise ValueError(
'DataFrame for label cannot have multiple columns')
label_dtypes = label.dtypes
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in label_dtypes):
raise ValueError('DataFrame.dtypes for label must be int, float or bool')
if not all(dtype.name in PANDAS_DTYPE_MAPPER
for dtype in label_dtypes):
raise ValueError(
'DataFrame.dtypes for label must be int, float or bool')
label = label.values.astype('float')
# pd.Series can be passed to xgb as it is
@ -318,6 +369,22 @@ def _maybe_dt_array(array):
return array
def _check_data(data, missing):
'''The missing value applies only to np.ndarray.'''
is_invalid = (not isinstance(data, np.ndarray)) and (missing is not None)
is_invalid = is_invalid and not math.isnan(missing)
if is_invalid:
raise ValueError(
'missing value only applies to dense input, ' +
'e.g. `numpy.ndarray`.' +
' For a possibly sparse data type: ' + str(type(data)) +
' please remove missing values or set it to nan.' +
' Current missing value is set to: ' + str(missing))
if isinstance(data, list):
warnings.warn('Initializing DMatrix from List is deprecated.',
DeprecationWarning)
class DMatrix(object):
"""Data Matrix used in XGBoost.
@ -336,15 +403,16 @@ class DMatrix(object):
"""
Parameters
----------
data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame
data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/
dt.Frame/cudf.DataFrame
Data source of DMatrix.
When data is string or os.PathLike type, it represents the path libsvm format
txt file, or binary file that xgboost can read from.
label : list or numpy 1-D array, optional
Label of the training data.
missing : float, optional
Value in the data which needs to be present as a missing value. If
None, defaults to np.nan.
Value in the dense input data (e.g. `numpy.ndarray`) which needs
to be present as a missing value. If None, defaults to np.nan.
weight : list or numpy 1-D array , optional
Weight for each instance.
@ -375,6 +443,8 @@ class DMatrix(object):
self._feature_types = feature_types
return
_check_data(data, missing)
data, feature_names, feature_types = _maybe_pandas_data(data,
feature_names,
feature_types)
@ -382,14 +452,11 @@ class DMatrix(object):
data, feature_names, feature_types = _maybe_dt_data(data,
feature_names,
feature_types)
label = _maybe_pandas_label(label)
label = _maybe_dt_array(label)
weight = _maybe_dt_array(weight)
if isinstance(data, list):
warnings.warn('Initializing DMatrix from List is deprecated.',
DeprecationWarning)
if isinstance(data, (STRING_TYPES, os_PathLike)):
handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(os_fspath(data)),
@ -404,6 +471,8 @@ class DMatrix(object):
self._init_from_npy2d(data, missing, nthread)
elif isinstance(data, DataTable):
self._init_from_dt(data, nthread)
elif _use_columnar_initializer(data):
self._init_from_columnar(data)
else:
try:
csr = scipy.sparse.csr_matrix(data)
@ -415,11 +484,15 @@ class DMatrix(object):
if label is not None:
if isinstance(label, np.ndarray):
self.set_label_npy2d(label)
elif _use_columnar_initializer(label):
self.set_interface_info('label', label)
else:
self.set_label(label)
if weight is not None:
    if isinstance(weight, np.ndarray):
        self.set_weight_npy2d(weight)
    # Dispatch on the weight object itself.  Testing `label` here would
    # route a cudf weight to `set_weight` whenever the label is not cudf,
    # and a non-cudf weight to the columnar path whenever the label is.
    elif _use_columnar_initializer(weight):
        self.set_interface_info('weight', weight)
    else:
        self.set_weight(weight)
@ -526,8 +599,19 @@ class DMatrix(object):
nthread))
self.handle = handle
def _init_from_columnar(self, df):
    '''Initialize DMatrix from columnar memory format.

    The frame is serialised to a JSON list of array interfaces, handed to
    ``XGDMatrixCreateFromArrayInterfaces``, and the resulting native handle
    is stored on ``self.handle``.
    '''
    json_interfaces = _extract_interface_from_cudf(df, False)
    new_handle = ctypes.c_void_p()
    status = _LIB.XGDMatrixCreateFromArrayInterfaces(
        json_interfaces, ctypes.byref(new_handle))
    _check_call(status)
    self.handle = new_handle
def __del__(self):
if hasattr(self, "handle") and self.handle is not None:
if hasattr(self, "handle") and self.handle:
_check_call(_LIB.XGDMatrixFree(self.handle))
self.handle = None
@ -593,6 +677,13 @@ class DMatrix(object):
c_data,
c_bst_ulong(len(data))))
def set_interface_info(self, field, data):
    '''Set info type property into DMatrix.

    `data` is serialised as a single array interface (its first column)
    and passed to ``XGDMatrixSetInfoFromInterface`` under the name `field`.
    '''
    json_interface = _extract_interface_from_cudf(data, True)
    status = _LIB.XGDMatrixSetInfoFromInterface(
        self.handle, c_str(field), json_interface)
    _check_call(status)
def set_float_info_npy2d(self, field, data):
"""Set float type property into the DMatrix
for numpy 2d array input
@ -732,7 +823,10 @@ class DMatrix(object):
margin: array like
Prediction margin of each datapoint
"""
self.set_float_info('base_margin', margin)
if _use_columnar_initializer(margin):
self.set_interface_info('base_margin', margin)
else:
self.set_float_info('base_margin', margin)
def set_group(self, group):
"""Set group size of DMatrix (used for ranking).
@ -742,9 +836,12 @@ class DMatrix(object):
group : array like
Group size of each group
"""
_check_call(_LIB.XGDMatrixSetGroup(self.handle,
c_array(ctypes.c_uint, group),
c_bst_ulong(len(group))))
if _use_columnar_initializer(group):
self.set_interface_info('group', group)
else:
_check_call(_LIB.XGDMatrixSetGroup(self.handle,
c_array(ctypes.c_uint, group),
c_bst_ulong(len(group))))
def get_label(self):
"""Get the label of the DMatrix.
@ -831,7 +928,8 @@ class DMatrix(object):
feature_names : list or None
"""
if self._feature_names is None:
self._feature_names = ['f{0}'.format(i) for i in range(self.num_col())]
self._feature_names = ['f{0}'.format(i)
for i in range(self.num_col())]
return self._feature_names
@property

View File

@ -1,4 +1,4 @@
// Copyright (c) 2014 by Contributors
// Copyright (c) 2014-2019 by Contributors
#include <xgboost/data.h>
#include <xgboost/learner.h>
@ -16,7 +16,7 @@
#include <string>
#include <memory>
#include "./c_api_error.h"
#include "c_api_error.h"
#include "../data/simple_csr_source.h"
#include "../common/math.h"
#include "../common/io.h"
@ -189,6 +189,16 @@ int XGDMatrixCreateFromDataIter(
API_END();
}
/*
 * Create a DMatrix from a JSON encoded list of array interfaces.
 *
 * c_json_strs: NUL terminated JSON document; must not be null.
 * out:         receives the newly allocated DMatrix handle.
 *
 * Errors raised between API_BEGIN and API_END are reported through the
 * usual C API return code.
 */
XGB_DLL int XGDMatrixCreateFromArrayInterfaces(
    char const* c_json_strs, DMatrixHandle* out) {
  API_BEGIN();
  // Constructing std::string from a null pointer is undefined behaviour;
  // turn it into a regular API error instead.
  CHECK(c_json_strs != nullptr) << "Invalid pointer argument: c_json_strs";
  std::string json_str {c_json_strs};
  std::unique_ptr<data::SimpleCSRSource> source (new data::SimpleCSRSource());
  source->CopyFrom(json_str);
  *out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
  API_END();
}
XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr,
const unsigned* indices,
const bst_float* data,
@ -679,9 +689,9 @@ XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle,
}
XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
const char* field,
const bst_float* info,
xgboost::bst_ulong len) {
const char* field,
const bst_float* info,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
@ -689,10 +699,20 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
API_END();
}
/*
 * Set a meta info field of a DMatrix from a JSON encoded array interface.
 *
 * handle:          DMatrix handle, validated by CHECK_HANDLE.
 * field:           name of the meta info field; must not be null.
 * interface_c_str: NUL terminated JSON array interface; must not be null.
 */
XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle,
                                          char const* field,
                                          char const* interface_c_str) {
  API_BEGIN();
  CHECK_HANDLE();
  // interface_c_str is implicitly converted to std::string by SetInfo, and
  // field is used as a C string; a null pointer would be undefined behaviour.
  CHECK(field != nullptr) << "Invalid pointer argument: field";
  CHECK(interface_c_str != nullptr)
      << "Invalid pointer argument: interface_c_str";
  static_cast<std::shared_ptr<DMatrix>*>(handle)
      ->get()->Info().SetInfo(field, interface_c_str);
  API_END();
}
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
const char* field,
const unsigned* info,
xgboost::bst_ulong len) {
const char* field,
const unsigned* info,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
@ -771,7 +791,7 @@ XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
xgboost::bst_ulong *out) {
API_BEGIN();
CHECK_HANDLE();
*out = static_cast<size_t>(
*out = static_cast<xgboost::bst_ulong>(
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_col_);
API_END();
}

View File

@ -1,160 +0,0 @@
#ifndef XGBOOST_COMMON_BITFIELD_CUH_
#define XGBOOST_COMMON_BITFIELD_CUH_
#include <bitset>
#include <string>
#include <iostream>
#include <sstream>
#include <cinttypes>
#include "span.h"
namespace xgboost {
__forceinline__ __device__ unsigned long long AtomicOr(unsigned long long* address,
unsigned long long val) {
unsigned long long int old = *address, assumed; // NOLINT
do {
assumed = old;
old = atomicCAS(address, assumed, val | assumed);
} while (assumed != old);
return old;
}
__forceinline__ __device__ unsigned long long AtomicAnd(unsigned long long* address,
unsigned long long val) {
unsigned long long int old = *address, assumed; // NOLINT
do {
assumed = old;
old = atomicCAS(address, assumed, val & assumed);
} while (assumed != old);
return old;
}
/*!
* \brief A non-owning type with auxiliary methods defined for manipulating bits.
*/
struct BitField {
using value_type = uint64_t;
static value_type constexpr kValueSize = sizeof(value_type) * 8;
static value_type constexpr kOne = 1UL; // force uint64_t
static_assert(kValueSize == 64, "uint64_t should be of 64 bits.");
struct Pos {
value_type int_pos {0};
value_type bit_pos {0};
};
common::Span<value_type> bits_;
public:
BitField() = default;
XGBOOST_DEVICE BitField(common::Span<value_type> bits) : bits_{bits} {}
XGBOOST_DEVICE BitField(BitField const& other) : bits_{other.bits_} {}
static size_t ComputeStorageSize(size_t size) {
auto pos = ToBitPos(size);
if (size < kValueSize) {
return 1;
}
if (pos.bit_pos != 0) {
return pos.int_pos + 2;
} else {
return pos.int_pos + 1;
}
}
XGBOOST_DEVICE static Pos ToBitPos(value_type pos) {
Pos pos_v;
if (pos == 0) {
return pos_v;
}
pos_v.int_pos = pos / kValueSize;
pos_v.bit_pos = pos % kValueSize;
return pos_v;
}
__device__ BitField& operator|=(BitField const& rhs) {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
size_t min_size = min(bits_.size(), rhs.bits_.size());
if (tid < min_size) {
bits_[tid] |= rhs.bits_[tid];
}
return *this;
}
__device__ BitField& operator&=(BitField const& rhs) {
size_t min_size = min(bits_.size(), rhs.bits_.size());
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < min_size) {
bits_[tid] &= rhs.bits_[tid];
}
return *this;
}
XGBOOST_DEVICE size_t Size() const { return kValueSize * bits_.size(); }
__device__ void Set(value_type pos) {
Pos pos_v = ToBitPos(pos);
value_type& value = bits_[pos_v.int_pos];
value_type set_bit = kOne << (kValueSize - pos_v.bit_pos - kOne);
static_assert(sizeof(unsigned long long int) == sizeof(value_type), "");
AtomicOr(reinterpret_cast<unsigned long long*>(&value), set_bit);
}
__device__ void Clear(value_type pos) {
Pos pos_v = ToBitPos(pos);
value_type& value = bits_[pos_v.int_pos];
value_type clear_bit = ~(kOne << (kValueSize - pos_v.bit_pos - kOne));
static_assert(sizeof(unsigned long long int) == sizeof(value_type), "");
AtomicAnd(reinterpret_cast<unsigned long long*>(&value), clear_bit);
}
XGBOOST_DEVICE bool Check(Pos pos_v) const {
value_type value = bits_[pos_v.int_pos];
value_type const test_bit = kOne << (kValueSize - pos_v.bit_pos - kOne);
value_type result = test_bit & value;
return static_cast<bool>(result);
}
XGBOOST_DEVICE bool Check(value_type pos) const {
Pos pos_v = ToBitPos(pos);
return Check(pos_v);
}
friend std::ostream& operator<<(std::ostream& os, BitField field) {
os << "Bits " << "storage size: " << field.bits_.size() << "\n";
for (size_t i = 0; i < field.bits_.size(); ++i) {
std::bitset<BitField::kValueSize> set(field.bits_[i]);
os << set << "\n";
}
return os;
}
};
inline void PrintDeviceBits(std::string name, BitField field) {
std::cout << "Bits: " << name << std::endl;
std::vector<BitField::value_type> h_field_bits(field.bits_.size());
thrust::copy(thrust::device_ptr<BitField::value_type>(field.bits_.data()),
thrust::device_ptr<BitField::value_type>(field.bits_.data() + field.bits_.size()),
h_field_bits.data());
BitField h_field;
h_field.bits_ = {h_field_bits.data(), h_field_bits.data() + h_field_bits.size()};
std::cout << h_field;
}
inline void PrintDeviceStorage(std::string name, common::Span<int32_t> list) {
std::cout << name << std::endl;
std::vector<int32_t> h_list(list.size());
thrust::copy(thrust::device_ptr<int32_t>(list.data()),
thrust::device_ptr<int32_t>(list.data() + list.size()),
h_list.data());
for (auto v : h_list) {
std::cout << v << ", ";
}
std::cout << std::endl;
}
}
#endif // XGBOOST_COMMON_BITFIELD_CUH_

248
src/common/bitfield.h Normal file
View File

@ -0,0 +1,248 @@
/*!
* Copyright 2019 by Contributors
* \file bitfield.h
*/
#ifndef XGBOOST_COMMON_BITFIELD_H_
#define XGBOOST_COMMON_BITFIELD_H_
#include <algorithm>
#include <bitset>
#include <cinttypes>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include "span.h"
namespace xgboost {
#if defined(__CUDACC__)
using BitFieldAtomicType = unsigned long long; // NOLINT
// Device-side bitwise OR of `val` into the word at `address`, built on a
// compare-and-swap retry loop.
//
// Returns the value that was stored at `address` immediately before the
// successful swap.
__forceinline__ __device__ BitFieldAtomicType AtomicOr(BitFieldAtomicType* address,
BitFieldAtomicType val) {
BitFieldAtomicType old = *address, assumed; // NOLINT
do {
// Remember what we believe is stored, then try to replace it with the
// OR-ed value.  atomicCAS only writes when the word still equals `assumed`.
assumed = old;
old = atomicCAS(address, assumed, val | assumed);
// A mismatch means the word changed in between: retry with the fresh value.
} while (assumed != old);
return old;
}
// Device-side bitwise AND of `val` into the word at `address`, built on a
// compare-and-swap retry loop.
//
// Returns the value that was stored at `address` immediately before the
// successful swap.
__forceinline__ __device__ BitFieldAtomicType AtomicAnd(BitFieldAtomicType* address,
BitFieldAtomicType val) {
BitFieldAtomicType old = *address, assumed; // NOLINT
do {
// Remember what we believe is stored, then try to replace it with the
// AND-ed value.  atomicCAS only writes when the word still equals `assumed`.
assumed = old;
old = atomicCAS(address, assumed, val & assumed);
// A mismatch means the word changed in between: retry with the fresh value.
} while (assumed != old);
return old;
}
#endif // defined(__CUDACC__)
/*!
 * \brief A non-owning type with auxiliary methods defined for manipulating bits.
 *
 * \tparam VT        Unsigned integer type used as the underlying storage word.
 * \tparam Direction Whether the bits start from left or from right.  The policy
 *                   must provide a static `Pos Shift(Pos)` function.
 */
template <typename VT, typename Direction>
struct BitFieldContainer {
  using value_type = VT;
  using pointer = value_type*;
  // Bit positions are addressed with size_t rather than value_type: with a
  // narrow storage word (e.g. unsigned char) value_type can only represent
  // positions 0..255, so larger positions and sizes were silently truncated.
  using index_type = size_t;

  static value_type constexpr kValueSize = sizeof(value_type) * 8;
  static value_type constexpr kOne = 1;  // force correct type.

  /*! \brief Location of one bit: index of the storage word and bit offset in it. */
  struct Pos {
    index_type int_pos {0};
    index_type bit_pos {0};
  };

  common::Span<value_type> bits_;
  static_assert(!std::is_signed<VT>::value, "Must use unsiged type as underlying storage.");

  /*! \brief Split a linear bit position into word index and in-word offset. */
  XGBOOST_DEVICE static Pos ToBitPos(index_type pos) {
    Pos pos_v;
    if (pos == 0) {
      return pos_v;
    }
    pos_v.int_pos = pos / kValueSize;
    pos_v.bit_pos = pos % kValueSize;
    return pos_v;
  }

 public:
  BitFieldContainer() = default;
  XGBOOST_DEVICE BitFieldContainer(common::Span<value_type> bits) : bits_{bits} {}
  XGBOOST_DEVICE BitFieldContainer(BitFieldContainer const& other) : bits_{other.bits_} {}

  /*!
   * \brief Number of storage words to allocate for `size` bits.
   *
   * NOTE(review): for size >= kValueSize this yields one word more than the
   * minimum required.  The formula is kept unchanged because callers may
   * depend on the resulting sizes.
   */
  static size_t ComputeStorageSize(size_t size) {
    auto pos = ToBitPos(size);
    if (size < kValueSize) {
      return 1;
    }

    if (pos.bit_pos != 0) {
      return pos.int_pos + 2;
    } else {
      return pos.int_pos + 1;
    }
  }

#if defined(__CUDA_ARCH__)
  // Device version: each thread handles the single word selected by its
  // global thread index.
  __device__ BitFieldContainer& operator|=(BitFieldContainer const& rhs) {
    auto tid = blockIdx.x * blockDim.x + threadIdx.x;
    size_t min_size = min(bits_.size(), rhs.bits_.size());
    if (tid < min_size) {
      bits_[tid] |= rhs.bits_[tid];
    }
    return *this;
  }
#else
  // Host version: loop over the words common to both operands.
  BitFieldContainer& operator|=(BitFieldContainer const& rhs) {
    size_t min_size = std::min(bits_.size(), rhs.bits_.size());
    for (size_t i = 0; i < min_size; ++i) {
      bits_[i] |= rhs.bits_[i];
    }
    return *this;
  }
#endif  // #if defined(__CUDA_ARCH__)

#if defined(__CUDA_ARCH__)
  // Device version: each thread handles the single word selected by its
  // global thread index.
  __device__ BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
    size_t min_size = min(bits_.size(), rhs.bits_.size());
    auto tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < min_size) {
      bits_[tid] &= rhs.bits_[tid];
    }
    return *this;
  }
#else
  // Host version: loop over the words common to both operands.
  BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
    size_t min_size = std::min(bits_.size(), rhs.bits_.size());
    for (size_t i = 0; i < min_size; ++i) {
      bits_[i] &= rhs.bits_[i];
    }
    return *this;
  }
#endif  // defined(__CUDA_ARCH__)

#if defined(__CUDA_ARCH__)
  /*! \brief Set the bit at `pos` using an atomic OR on the storage word. */
  __device__ void Set(index_type pos) {
    Pos pos_v = Direction::Shift(ToBitPos(pos));
    value_type& value = bits_[pos_v.int_pos];
    value_type set_bit = kOne << pos_v.bit_pos;
    static_assert(sizeof(BitFieldAtomicType) == sizeof(value_type), "");
    AtomicOr(reinterpret_cast<BitFieldAtomicType*>(&value), set_bit);
  }
  /*! \brief Clear the bit at `pos` using an atomic AND on the storage word. */
  __device__ void Clear(index_type pos) {
    Pos pos_v = Direction::Shift(ToBitPos(pos));
    value_type& value = bits_[pos_v.int_pos];
    value_type clear_bit = ~(kOne << pos_v.bit_pos);
    static_assert(sizeof(BitFieldAtomicType) == sizeof(value_type), "");
    AtomicAnd(reinterpret_cast<BitFieldAtomicType*>(&value), clear_bit);
  }
#else
  /*! \brief Set the bit at `pos` (plain, non-atomic host version). */
  void Set(index_type pos) {
    Pos pos_v = Direction::Shift(ToBitPos(pos));
    value_type& value = bits_[pos_v.int_pos];
    value_type set_bit = kOne << pos_v.bit_pos;
    value |= set_bit;
  }
  /*! \brief Clear the bit at `pos` (plain, non-atomic host version). */
  void Clear(index_type pos) {
    Pos pos_v = Direction::Shift(ToBitPos(pos));
    value_type& value = bits_[pos_v.int_pos];
    value_type clear_bit = ~(kOne << pos_v.bit_pos);
    value &= clear_bit;
  }
#endif  // defined(__CUDA_ARCH__)

  /*! \brief Test the bit at an already decomposed (un-shifted) position. */
  XGBOOST_DEVICE bool Check(Pos pos_v) const {
    pos_v = Direction::Shift(pos_v);
    value_type const value = bits_[pos_v.int_pos];
    value_type const test_bit = kOne << pos_v.bit_pos;
    value_type result = test_bit & value;
    return static_cast<bool>(result);
  }
  /*! \brief Test the bit at linear position `pos`. */
  XGBOOST_DEVICE bool Check(index_type pos) const {
    Pos pos_v = ToBitPos(pos);
    return Check(pos_v);
  }

  /*! \brief Number of bits that the referenced storage can hold. */
  XGBOOST_DEVICE size_t Size() const { return kValueSize * bits_.size(); }

  XGBOOST_DEVICE pointer Data() const { return bits_.data(); }

  /*! \brief Print the storage size followed by one line of bits per storage word. */
  friend std::ostream& operator<<(std::ostream& os, BitFieldContainer<VT, Direction> field) {
    os << "Bits " << "storage size: " << field.bits_.size() << "\n";
    for (typename common::Span<value_type>::index_type i = 0; i < field.bits_.size(); ++i) {
      std::bitset<BitFieldContainer<VT, Direction>::kValueSize> bset(field.bits_[i]);
      os << bset << "\n";
    }
    return os;
  }
};
// Left-to-right policy: bit 0 of every storage word is its most significant bit.
template <typename VT>
struct LBitsPolicy : public BitFieldContainer<VT, LBitsPolicy<VT>> {
  using Container = BitFieldContainer<VT, LBitsPolicy<VT>>;
  using Pos = typename Container::Pos;
  using value_type = typename Container::value_type;

  // Mirror the in-word offset so that it counts from the most significant bit.
  XGBOOST_DEVICE static Pos Shift(Pos pos) {
    Pos shifted = pos;
    shifted.bit_pos = Container::kValueSize - pos.bit_pos - Container::kOne;
    return shifted;
  }
};
// Right-to-left policy: within each storage word bits are counted from the
// least significant bit, while the storage words themselves are still indexed
// from left to right.
template <typename VT>
struct RBitsPolicy : public BitFieldContainer<VT, RBitsPolicy<VT>> {
  using Container = BitFieldContainer<VT, RBitsPolicy<VT>>;
  using value_type = typename Container::value_type;
  using Pos = typename Container::Pos;

  // The in-word offset is already the shift amount; nothing to adjust.
  XGBOOST_DEVICE static Pos Shift(Pos pos) { return pos; }
};
// Format: <Direction>BitField<size of underlying type>, underlying type must be unsigned.
using LBitField64 = BitFieldContainer<uint64_t, LBitsPolicy<uint64_t>>;
using RBitField8 = BitFieldContainer<unsigned char, RBitsPolicy<unsigned char>>;
#if defined(__CUDACC__)
template <typename V, typename D>
inline void PrintDeviceBits(std::string name, BitFieldContainer<V, D> field) {
std::cout << "Bits: " << name << std::endl;
std::vector<typename BitFieldContainer<V, D>::value_type> h_field_bits(field.bits_.size());
thrust::copy(thrust::device_ptr<typename BitFieldContainer<V, D>::value_type>(field.bits_.data()),
thrust::device_ptr<typename BitFieldContainer<V, D>::value_type>(
field.bits_.data() + field.bits_.size()),
h_field_bits.data());
BitFieldContainer<V, D> h_field;
h_field.bits_ = {h_field_bits.data(), h_field_bits.data() + h_field_bits.size()};
std::cout << h_field;
}
inline void PrintDeviceStorage(std::string name, common::Span<int32_t> list) {
std::cout << name << std::endl;
std::vector<int32_t> h_list(list.size());
thrust::copy(thrust::device_ptr<int32_t>(list.data()),
thrust::device_ptr<int32_t>(list.data() + list.size()),
h_list.data());
for (auto v : h_list) {
std::cout << v << ", ";
}
std::cout << std::endl;
}
#endif // defined(__CUDACC__)
} // namespace xgboost
#endif // XGBOOST_COMMON_BITFIELD_H_

View File

@ -51,12 +51,18 @@ inline ncclResult_t ThrowOnNcclError(ncclResult_t code, const char *file,
}
#endif
inline void CudaCheckPointerDevice(void* ptr) {
inline int32_t CudaGetPointerDevice(void* ptr) {
int32_t device = -1;
cudaPointerAttributes attr;
dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
int ptr_device = attr.device;
device = attr.device;
return device;
}
inline void CudaCheckPointerDevice(void* ptr) {
auto ptr_device = CudaGetPointerDevice(ptr);
int cur_device = -1;
cudaGetDevice(&cur_device);
dh::safe_cuda(cudaGetDevice(&cur_device));
CHECK_EQ(ptr_device, cur_device) << "pointer device: " << ptr_device
<< "current device: " << cur_device;
}

View File

@ -365,8 +365,9 @@ struct HostDeviceVectorImpl {
void Shard(const GPUDistribution& distribution) {
if (distribution_ == distribution) { return; }
CHECK(distribution_.IsEmpty())
<< "This: " << distribution_.Devices().Size() << ", "
<< "Others: " << distribution.Devices().Size();
<< "Data resides on different GPUs: " << "ID: "
<< *(distribution_.Devices().begin()) << " and ID: "
<< *(distribution.Devices().begin());
distribution_ = distribution;
InitShards();
}
@ -412,6 +413,7 @@ struct HostDeviceVectorImpl {
perm_h_.Grant(access);
return;
}
std::lock_guard<std::mutex> lock(mutex_);
if (data_h_.size() != size_d_) { data_h_.resize(size_d_); }
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
shard.LazySyncHost(access);

82
src/data/columnar.h Normal file
View File

@ -0,0 +1,82 @@
/*!
* Copyright 2019 by Contributors
* \file columnar.h
* \brief Basic structure holding a reference to arrow columnar data format.
*/
#ifndef XGBOOST_DATA_COLUMNAR_H_
#define XGBOOST_DATA_COLUMNAR_H_
#include <cinttypes>
#include <map>
#include <string>
#include "xgboost/data.h"
#include "xgboost/json.h"
#include "../common/span.h"
#include "../common/bitfield.h"
namespace xgboost {
// Description of a single column of columnar input.  The members are views
// (a Span and a bit field over a Span), so this struct itself stores no
// element data.
struct Columnar {
using mask_type = unsigned char;
using index_type = int32_t;
// Column values.
common::Span<float> data;
// Bit field over 8-bit words; presumably the validity (null) mask of the
// column -- TODO confirm against the code that fills it.
RBitField8 valid;
// NOTE(review): presumably the number of rows in the column -- confirm.
int32_t size;
// NOTE(review): presumably the number of null entries flagged by `valid`
// -- confirm.
int32_t null_count;
};
// Messages for the errors commonly hit while parsing the columnar format.
//
// The parameterised messages (Dimension, ofType) are assembled in a
// function-local static string; the returned pointer is overwritten by the
// next call to the same function.
struct ColumnarErrors {
  static char const* Contigious() { return "Memory should be contigious."; }

  static char const* TypestrFormat() {
    return "`typestr` should be of format <endian><type><size>.";
  }

  // Not supported in Apache Arrow.
  static char const* BigEndian() { return "Big endian is not supported."; }

  static char const* Dimension(int32_t d) {
    static std::string message;
    message = "Only " + std::to_string(d) + " dimensional array is valid.";
    return message.c_str();
  }

  static char const* Version() {
    return "Only version 1 of __cuda_array_interface__ is being supported.";
  }

  static char const* toFloat() {
    return "Please convert the input into float32 first.";
  }

  static char const* toUInt() {
    return "Please convert the Group into unsigned 32 bit integers first.";
  }

  static char const* ofType(std::string type) {
    static std::string message;
    message = " should be of " + type + " type.";
    return message.c_str();
  }
};
template <typename PtrType>
PtrType GetPtrFromArrayData(std::map<std::string, Json> const& obj) {
if (obj.find("data") == obj.cend()) {
LOG(FATAL) << "Empty data passed in.";
}
auto p_data = reinterpret_cast<PtrType>(static_cast<size_t>(
get<Integer const>(
get<Array const>(
obj.at("data"))
.at(0))));
return p_data;
}
} // namespace xgboost
#endif // XGBOOST_DATA_COLUMNAR_H_

View File

@ -1,11 +1,12 @@
/*!
* Copyright 2015 by Contributors
* Copyright 2015-2019 by Contributors
* \file data.cc
*/
#include <xgboost/data.h>
#include <xgboost/logging.h>
#include <dmlc/registry.h>
#include <cstring>
#include "./sparse_page_writer.h"
#include "./simple_dmatrix.h"
#include "./simple_csr_source.h"
@ -110,7 +111,6 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname,
default: LOG(FATAL) << "Unknown data type" << dtype; \
} \
void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
if (!std::strcmp(key, "root_index")) {
root_index_.resize(num);
@ -139,9 +139,17 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
for (size_t i = 1; i < group_ptr_.size(); ++i) {
group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i];
}
} else {
LOG(FATAL) << "Unknown metainfo: " << key;
}
}
#if !defined(XGBOOST_USE_CUDA)
// Fallback for builds without XGBOOST_USE_CUDA: setting meta info from an
// array interface always fails with a fatal log message.
void MetaInfo::SetInfo(const char* /*c_key*/, std::string const& /*interface_str*/) {
  LOG(FATAL) << "XGBoost version is not compiled with GPU support";
}
#endif // !defined(XGBOOST_USE_CUDA)
DMatrix* DMatrix::Load(const std::string& uri,
bool silent,
bool load_row_split,

86
src/data/data.cu Normal file
View File

@ -0,0 +1,86 @@
/*!
* Copyright 2019 by XGBoost Contributors
*
* \file data.cu
*/
#include "xgboost/data.h"
#include "xgboost/logging.h"
#include "xgboost/json.h"
#include "columnar.h"
#include "../common/device_helpers.cuh"
namespace xgboost {
/*!
 * \brief Set a meta info field from a JSON encoded array interface whose data
 *        pointer refers to device memory.
 *
 * \param c_key         One of "label", "weight", "base_margin" or "group".
 * \param interface_str JSON object describing a 1 dimensional, dense,
 *                      little endian array (version 1 of the interface).
 *
 * "label", "weight" and "base_margin" must be `f4` and are copied device to
 * device into the corresponding vector.  "group" must be `u4`; the group sizes
 * are copied to the host and turned into the cumulative `group_ptr_`.
 */
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
  Json j_arr = Json::Load({interface_str.c_str(), interface_str.size()});
  auto const& j_arr_obj = get<Object>(j_arr);
  std::string key {c_key};
  auto version = get<Integer const>(j_arr_obj.at("version"));
  CHECK_EQ(version, 1) << ColumnarErrors::Version();

  if (j_arr_obj.find("mask") != j_arr_obj.cend()) {
    LOG(FATAL) << "Meta info " << key << " should be dense, found validity mask";
  }

  auto typestr = get<String const>(j_arr_obj.at("typestr"));
  CHECK_EQ(typestr.size(), 3) << ColumnarErrors::TypestrFormat();
  CHECK_NE(typestr.front(), '>') << ColumnarErrors::BigEndian();

  auto j_shape = get<Array const>(j_arr_obj.at("shape"));
  CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);
  auto length = get<Integer const>(j_shape.at(0));
  CHECK_GT(length, 0) << "Label set cannot be empty.";

  if (j_arr_obj.find("strides") != j_arr_obj.cend()) {
    auto strides = get<Array const>(j_arr_obj.at("strides"));
    // Every accepted element type is 4 bytes wide, so a stride of 4 means
    // the elements are tightly packed.
    CHECK_EQ(get<Integer>(strides.at(0)), 4) << ColumnarErrors::Contigious();
  }

  float* p_data = GetPtrFromArrayData<float*>(j_arr_obj);

  // Operate on the device that owns the input buffer.
  cudaPointerAttributes attr;
  dh::safe_cuda(cudaPointerGetAttributes(&attr, p_data));
  int32_t ptr_device = attr.device;
  dh::safe_cuda(cudaSetDevice(ptr_device));

  thrust::device_ptr<float> p_src {p_data};

  HostDeviceVector<float>* dst = nullptr;
  if (key == "root_index") {
    LOG(FATAL) << "root index for columnar data is not supported.";
  } else if (key == "label") {
    dst = &labels_;
    CHECK_EQ(typestr.at(1), 'f') << "Label"
                                 << ColumnarErrors::ofType("floating point");
    CHECK_EQ(typestr.at(2), '4') << ColumnarErrors::toFloat();
  } else if (key == "weight") {
    dst = &weights_;
    CHECK_EQ(typestr.at(1), 'f') << "Weight"
                                 << ColumnarErrors::ofType("floating point");
    CHECK_EQ(typestr.at(2), '4') << ColumnarErrors::toFloat();
  } else if (key == "base_margin") {
    dst = &base_margin_;
    CHECK_EQ(typestr.at(1), 'f') << "Base Margin"
                                 << ColumnarErrors::ofType("floating point");
    CHECK_EQ(typestr.at(2), '4') << ColumnarErrors::toFloat();
  } else if (key == "group") {
    CHECK_EQ(typestr.at(1), 'u') << "Group"
                                 << ColumnarErrors::ofType("unsigned 32 bit integers");
    CHECK_EQ(typestr.at(2), '4') << ColumnarErrors::toUInt();
    // The buffer holds unsigned 32 bit integers (verified just above), so it
    // must be read through a uint32_t pointer.  Reading it through the float
    // pointer would reinterpret the integer bit patterns as floats and
    // produce bogus group sizes.
    thrust::device_ptr<uint32_t> p_group {
      GetPtrFromArrayData<uint32_t*>(j_arr_obj)};
    group_ptr_.resize(length + 1);
    group_ptr_[0] = 0;
    // Ranking is not performed on device.
    thrust::copy(p_group, p_group + length, group_ptr_.begin() + 1);
    // Turn the group sizes into cumulative offsets.
    for (size_t i = 1; i < group_ptr_.size(); ++i) {
      group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i];
    }
    return;
  } else {
    LOG(FATAL) << "Unknown metainfo: " << key;
  }

  CHECK(dst != nullptr) << "Unknown metainfo: " << key;
  dst->Reshard(GPUDistribution(GPUSet::Range(ptr_device, 1)));
  dst->Resize(length);
  auto p_dst = thrust::device_pointer_cast(dst->DevicePointer(0));
  thrust::copy(p_src, p_src + length, p_dst);
}
} // namespace xgboost

View File

@ -1,11 +1,14 @@
/*!
* Copyright 2015 by Contributors
* Copyright 2015-2019 by Contributors
* \file simple_csr_source.cc
*/
#include <dmlc/base.h>
#include <xgboost/logging.h>
#include <xgboost/json.h>
#include <limits>
#include "./simple_csr_source.h"
#include "simple_csr_source.h"
#include "columnar.h"
namespace xgboost {
namespace data {
@ -117,5 +120,143 @@ const SparsePage& SimpleCSRSource::Value() const {
return page_;
}
/*!
* Please be careful that, in official specification, the only three required fields are
* `shape', `version' and `typestr'. Any other is optional, including `data'. But here
* we have two additional requirements for input data:
*
* - `data' field is required, passing in an empty dataset is not accepted, as most (if
* not all) of our algorithms don't have test for empty dataset. An error is better
* than a crash.
*
* - `null_count' is required when `mask' is present. We could compute `null_count'
* ourselves and copy the result back to host for memory allocation, but it's in the
* specification of Apache Arrow, hence it should be readily available.
*
* Sample input:
* [
* {
* "shape": [
* 10
* ],
* "strides": [
* 4
* ],
* "data": [
* 30074864128,
* false
* ],
* "typestr": "<f4",
* "version": 1,
* "mask": {
* "shape": [
* 64
* ],
* "strides": [
* 1
* ],
* "data": [
* 30074864640,
* false
* ],
* "typestr": "|i1",
* "version": 1,
* "null_count": 1
* }
* }
* ]
*/
void SimpleCSRSource::CopyFrom(std::string const& cuda_interfaces_str) {
Json interfaces = Json::Load({cuda_interfaces_str.c_str(),
cuda_interfaces_str.size()});
std::vector<Json> const& columns = get<Array>(interfaces);
size_t n_columns = columns.size();
CHECK_GT(n_columns, 0);
std::vector<Columnar> foreign_cols(n_columns);
for (size_t i = 0; i < columns.size(); ++i) {
CHECK(IsA<Object>(columns[i]));
auto const& column = get<Object const>(columns[i]);
auto version = get<Integer const>(column.at("version"));
CHECK_EQ(version, 1) << ColumnarErrors::Version();
// Find null mask (validity mask) field
// Mask object is also an array interface, but with different requirements.
// TODO(trivialfis): Abstract this into a class that accept a json
// object and turn it into an array (for cupy and numba).
common::Span<RBitField8::value_type> s_mask;
int32_t null_count {0};
if (column.find("mask") != column.cend()) {
auto const& j_mask = get<Object const>(column.at("mask"));
auto p_mask = GetPtrFromArrayData<RBitField8::value_type*>(j_mask);
auto j_shape = get<Array const>(j_mask.at("shape"));
CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);
CHECK_EQ(get<Integer>(j_shape.front()) % 8, 0) <<
"Length of validity map must be a multiple of 8 bytes.";
int64_t size = get<Integer>(j_shape.at(0)) *
sizeof(unsigned char) / sizeof(RBitField8::value_type);
s_mask = {p_mask, size};
auto typestr = get<String const>(j_mask.at("typestr"));
CHECK_EQ(typestr.size(), 3) << ColumnarErrors::TypestrFormat();
CHECK_NE(typestr.front(), '>') << ColumnarErrors::BigEndian();
CHECK_EQ(typestr.at(1), 'i') << "mask" << ColumnarErrors::ofType("unsigned char");
CHECK_EQ(typestr.at(2), '1') << "mask" << ColumnarErrors::toUInt();
CHECK(j_mask.find("null_count") != j_mask.cend()) <<
"Column with null mask must include null_count as "
"part of mask object for XGBoost.";
null_count = get<Integer const>(j_mask.at("null_count"));
}
// Find data field
if (column.find("data") == column.cend()) {
LOG(FATAL) << "Empty dataset passed in.";
}
auto typestr = get<String const>(column.at("typestr"));
CHECK_EQ(typestr.size(), 3) << ColumnarErrors::TypestrFormat();
CHECK_NE(typestr.front(), '>') << ColumnarErrors::BigEndian();
CHECK_EQ(typestr.at(1), 'f') << "data" << ColumnarErrors::ofType("floating point");
CHECK_EQ(typestr.at(2), '4') << ColumnarErrors::toFloat();
auto j_shape = get<Array const>(column.at("shape"));
CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);
if (column.find("strides") != column.cend()) {
auto strides = get<Array const>(column.at("strides"));
CHECK_EQ(strides.size(), 1) << ColumnarErrors::Dimension(1);
CHECK_EQ(get<Integer>(strides.at(0)), 4) << ColumnarErrors::Contigious();
}
auto length = get<Integer const>(j_shape.at(0));
float* p_data = GetPtrFromArrayData<float*>(column);
common::Span<float> s_data {p_data, length};
foreign_cols[i].data = s_data;
foreign_cols[i].valid = RBitField8(s_mask);
foreign_cols[i].size = s_data.size();
foreign_cols[i].null_count = null_count;
}
info.num_col_ = n_columns;
info.num_row_ = foreign_cols[0].size;
for (size_t i = 0; i < n_columns; ++i) {
CHECK_EQ(foreign_cols[0].size, foreign_cols[i].size);
info.num_nonzero_ += foreign_cols[i].data.size() - foreign_cols[i].null_count;
}
this->FromDeviceColumnar(foreign_cols);
}
#if !defined(XGBOOST_USE_CUDA)
// Fallback used when XGBOOST_USE_CUDA is not defined: device columns cannot be
// consumed, so any call is a fatal error.
void SimpleCSRSource::FromDeviceColumnar(std::vector<Columnar> /*cols*/) {
  LOG(FATAL) << "XGBoost version is not compiled with GPU support";
}
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace data
} // namespace xgboost

View File

@ -0,0 +1,117 @@
/*!
* Copyright 2019 by XGBoost Contributors
*
* \file simple_csr_source.cuh
* \brief An extension for the simple CSR source in-memory data structure to accept
* foreign columnar.
*/
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/scan.h>
#include <xgboost/base.h>
#include <xgboost/data.h>
#include <vector>
#include <algorithm>
#include "simple_csr_source.h"
#include "columnar.h"
#include "../common/bitfield.h"
#include "../common/device_helpers.cuh"
namespace xgboost {
namespace data {
/*!
 * \brief Count, for every row, how many columns hold a valid value in that row.
 *
 * The count for row r is added to offsets[r + 1].  Each block handles the column
 * selected by its block index; the threads of the block step through the rows
 * kBlockThreads at a time.
 *
 * \tparam kBlockThreads Step between consecutive rows visited by one thread.
 * \param columns Column descriptors.
 * \param n_rows  Number of rows to visit in each column.
 * \param offsets Per-row counters; entry r + 1 is incremented for row r.
 */
template <size_t kBlockThreads>
__global__ void CountValidKernel(common::Span<Columnar const> columns,
                                 int32_t const n_rows,
                                 common::Span<size_t> offsets) {
  auto const column_idx = blockIdx.x;
  if (column_idx >= columns.size()) {
    return;
  }

  RBitField8 const validity = columns[column_idx].valid;
  for (auto row = threadIdx.x; row < n_rows; row += kBlockThreads) {
    // A column without a mask (null storage) treats every row as valid.
    bool const is_valid = validity.Data() == nullptr || validity.Check(row);
    if (!is_valid) {
      continue;
    }
    // Blocks working on different columns update the same per-row counter,
    // hence the atomic increment.
    atomicAdd(reinterpret_cast<BitFieldAtomicType*>(&offsets[row + 1]),
              static_cast<BitFieldAtomicType>(1));
  }
}
/*!
 * \brief Scatter the valid values of one column into the CSR entry buffer.
 *
 * One thread handles one row.  For a valid row the value is written at the
 * position currently stored in offsets[row], after which that position is
 * advanced by one.
 *
 * \param column   Column to scatter.
 * \param colid    Feature index recorded in every written entry.
 * \param offsets  Per-row write position; updated in place.
 * \param out_data Destination entry buffer.
 */
__global__ void CreateCSRKernel(Columnar const column,
                                int32_t colid,
                                common::Span<size_t> offsets,
                                common::Span<Entry> out_data) {
  auto tid = threadIdx.x + blockDim.x * blockIdx.x;
  if (column.size <= tid) {
    return;
  }

  // A column without a mask (null storage) treats every row as valid.
  if (column.valid.Data() == nullptr || column.valid.Check(tid)) {
    // Keep the full width of the stored offset: narrowing it to a 32 bit
    // integer would produce a wrong position for large entry buffers.
    size_t const oid = offsets[tid];
    out_data[oid].fvalue = column.data[tid];
    out_data[oid].index = colid;
    offsets[tid] += 1;
  }
}
void SimpleCSRSource::FromDeviceColumnar(std::vector<Columnar> cols) {
uint64_t const n_cols = cols.size();
uint64_t const n_rows = cols[0].size;
auto ptr = cols[0].data.data();
int32_t device = dh::CudaGetPointerDevice(ptr);
CHECK_NE(device, -1);
for (int32_t i = 1; i < n_cols; ++i) {
auto ptr = cols[i].data.data();
int32_t ptr_device = dh::CudaGetPointerDevice(ptr);
CHECK_EQ(device, ptr_device)
<< "GPU ID at 0^th column: " << device << ", "
<< "GPU ID at column " << i << ": " << ptr_device;
}
dh::safe_cuda(cudaSetDevice(device));
GPUSet devices = GPUSet::Range(device, 1);
page_.offset.Reshard(GPUDistribution(devices));
page_.offset.Resize(info.num_row_ + 1);
page_.data.Reshard(GPUDistribution(devices));
page_.data.Resize(info.num_nonzero_);
auto s_data = page_.data.DeviceSpan(device);
auto s_offsets = page_.offset.DeviceSpan(device);
CHECK_EQ(s_offsets.size(), n_rows + 1);
int32_t constexpr kThreads = 256;
dh::device_vector<Columnar> d_cols(cols);
auto s_d_cols = dh::ToSpan(d_cols);
dh::safe_cuda(cudaMemset(s_offsets.data(), 0, sizeof(int32_t) * (n_rows + 1)));
CountValidKernel<kThreads><<<n_cols, kThreads>>>(s_d_cols, n_rows, s_offsets);
thrust::device_ptr<size_t> p_offsets(s_offsets.data());
CHECK_GE(s_offsets.size(), n_rows + 1);
thrust::inclusive_scan(p_offsets, p_offsets + n_rows + 1, p_offsets);
// Created for building csr matrix, where we need to change index
// after processing each column.
dh::device_vector<size_t> tmp_offset(page_.offset.Size());
thrust::copy(p_offsets, p_offsets + n_rows + 1, tmp_offset.begin());
int32_t kBlocks = common::DivRoundUp(n_rows, kThreads);
for (size_t col = 0; col < n_cols; ++col) {
CreateCSRKernel<<<kBlocks, kThreads>>>(d_cols[col], col, dh::ToSpan(tmp_offset), s_data);
}
}
} // namespace data
} // namespace xgboost

View File

@ -10,9 +10,12 @@
#include <xgboost/base.h>
#include <xgboost/data.h>
#include <vector>
#include <algorithm>
#include <algorithm>
#include <string>
#include <vector>
#include "columnar.h"
namespace xgboost {
namespace data {
@ -27,7 +30,6 @@ namespace data {
*/
class SimpleCSRSource : public DataSource<SparsePage> {
public:
// public data members
// MetaInfo info; // inheritated from DataSource
SparsePage page_;
/*! \brief default constructor */
@ -47,6 +49,11 @@ class SimpleCSRSource : public DataSource<SparsePage> {
* \param info The additional information reflected in the parser.
*/
void CopyFrom(dmlc::Parser<uint32_t>* src);
/*!
* \brief copy content of data from foreign **GPU** columnar buffer.
* \param interfaces_str JSON representation of cuda array interfaces.
*/
void CopyFrom(std::string const& cuda_interfaces_str);
/*!
* \brief Load data from binary stream.
* \param fi the pointer to load data from.
@ -67,6 +74,11 @@ class SimpleCSRSource : public DataSource<SparsePage> {
static const int kMagic = 0xffffab01;
private:
/*!
* \brief copy content of data from foreign GPU columnar buffer.
* \param cols foreign columns data buffer.
*/
void FromDeviceColumnar(std::vector<Columnar> cols);
/*! \brief internal variable, used to support iterator interface */
bool at_first_{true};
};

View File

@ -22,9 +22,6 @@
namespace xgboost {
BitField::value_type constexpr BitField::kValueSize;
BitField::value_type constexpr BitField::kOne;
size_t FeatureInteractionConstraint::Features() const {
return d_sets_ptr_.size() - 1;
}
@ -51,7 +48,7 @@ void FeatureInteractionConstraint::Configure(
}
n_sets_ = h_feature_constraints.size();
size_t const n_feat_storage = BitField::ComputeStorageSize(n_features);
size_t const n_feat_storage = LBitField64::ComputeStorageSize(n_features);
if (n_feat_storage == 0 && n_features != 0) {
LOG(FATAL) << "Wrong storage size, n_features: " << n_features;
}
@ -70,13 +67,13 @@ void FeatureInteractionConstraint::Configure(
node_constraints_.resize(n_nodes);
node_constraints_storage_.resize(n_nodes);
for (auto& n : node_constraints_storage_) {
n.resize(BitField::ComputeStorageSize(n_features));
n.resize(LBitField64::ComputeStorageSize(n_features));
}
for (size_t i = 0; i < node_constraints_storage_.size(); ++i) {
auto span = dh::ToSpan(node_constraints_storage_[i]);
node_constraints_[i] = BitField(span);
node_constraints_[i] = LBitField64(span);
}
s_node_constraints_ = common::Span<BitField>(node_constraints_.data(),
s_node_constraints_ = common::Span<LBitField64>(node_constraints_.data(),
node_constraints_.size());
// Represent constraints as CSR format, flatten is the value vector,
@ -131,14 +128,14 @@ void FeatureInteractionConstraint::Configure(
s_sets_ = dh::ToSpan(d_sets_);
s_sets_ptr_ = dh::ToSpan(d_sets_ptr_);
d_feature_buffer_storage_.resize(BitField::ComputeStorageSize(n_features));
d_feature_buffer_storage_.resize(LBitField64::ComputeStorageSize(n_features));
feature_buffer_ = dh::ToSpan(d_feature_buffer_storage_);
// --- Initialize result buffers.
output_buffer_bits_storage_.resize(BitField::ComputeStorageSize(n_features));
output_buffer_bits_ = BitField(dh::ToSpan(output_buffer_bits_storage_));
input_buffer_bits_storage_.resize(BitField::ComputeStorageSize(n_features));
input_buffer_bits_ = BitField(dh::ToSpan(input_buffer_bits_storage_));
output_buffer_bits_storage_.resize(LBitField64::ComputeStorageSize(n_features));
output_buffer_bits_ = LBitField64(dh::ToSpan(output_buffer_bits_storage_));
input_buffer_bits_storage_.resize(LBitField64::ComputeStorageSize(n_features));
input_buffer_bits_ = LBitField64(dh::ToSpan(input_buffer_bits_storage_));
result_buffer_.resize(n_features);
s_result_buffer_ = dh::ToSpan(result_buffer_);
}
@ -156,7 +153,7 @@ void FeatureInteractionConstraint::Reset() {
}
__global__ void ClearBuffersKernel(
BitField result_buffer_output, BitField result_buffer_input) {
LBitField64 result_buffer_output, LBitField64 result_buffer_input) {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < result_buffer_output.Size()) {
result_buffer_output.Clear(tid);
@ -185,7 +182,7 @@ common::Span<int32_t> FeatureInteractionConstraint::QueryNode(int32_t node_id) {
thrust::counting_iterator<int32_t> begin(0);
thrust::counting_iterator<int32_t> end(result_buffer_.size());
auto p_result_buffer = result_buffer_.data();
BitField node_constraints = s_node_constraints_[node_id];
LBitField64 node_constraints = s_node_constraints_[node_id];
thrust::device_ptr<int32_t> const out_end = thrust::copy_if(
thrust::device,
@ -201,16 +198,16 @@ common::Span<int32_t> FeatureInteractionConstraint::QueryNode(int32_t node_id) {
}
__global__ void SetInputBufferKernel(common::Span<int32_t> feature_list_input,
BitField result_buffer_input) {
LBitField64 result_buffer_input) {
uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < feature_list_input.size()) {
result_buffer_input.Set(feature_list_input[tid]);
}
}
__global__ void QueryFeatureListKernel(BitField node_constraints,
BitField result_buffer_input,
BitField result_buffer_output) {
__global__ void QueryFeatureListKernel(LBitField64 node_constraints,
LBitField64 result_buffer_input,
LBitField64 result_buffer_output) {
result_buffer_output |= node_constraints;
result_buffer_output &= result_buffer_input;
}
@ -223,7 +220,7 @@ common::Span<int32_t> FeatureInteractionConstraint::Query(
ClearBuffers();
BitField node_constraints = s_node_constraints_[nid];
LBitField64 node_constraints = s_node_constraints_[nid];
CHECK_EQ(input_buffer_bits_.Size(), output_buffer_bits_.Size());
int constexpr kBlockThreads = 256;
@ -237,7 +234,7 @@ common::Span<int32_t> FeatureInteractionConstraint::Query(
thrust::counting_iterator<int32_t> begin(0);
thrust::counting_iterator<int32_t> end(result_buffer_.size());
BitField local_result_buffer = output_buffer_bits_;
LBitField64 local_result_buffer = output_buffer_bits_;
thrust::device_ptr<int32_t> const out_end = thrust::copy_if(
thrust::device,
@ -257,7 +254,7 @@ common::Span<int32_t> FeatureInteractionConstraint::Query(
// Find interaction sets for each feature, then store all features in
// those sets in a buffer.
__global__ void RestoreFeatureListFromSetsKernel(
BitField feature_buffer,
LBitField64 feature_buffer,
int32_t fid,
common::Span<int32_t> feature_interactions,
@ -283,11 +280,11 @@ __global__ void RestoreFeatureListFromSetsKernel(
}
}
__global__ void InteractionConstraintSplitKernel(BitField feature,
__global__ void InteractionConstraintSplitKernel(LBitField64 feature,
int32_t feature_id,
BitField node,
BitField left,
BitField right) {
LBitField64 node,
LBitField64 left,
LBitField64 right) {
auto tid = threadIdx.x + blockDim.x * blockIdx.x;
if (tid > node.Size()) {
return;
@ -324,9 +321,9 @@ void FeatureInteractionConstraint::Split(
CHECK_LT(right_id, s_node_constraints_.size());
CHECK_NE(s_node_constraints_.size(), 0);
BitField node = s_node_constraints_[node_id];
BitField left = s_node_constraints_[left_id];
BitField right = s_node_constraints_[right_id];
LBitField64 node = s_node_constraints_[node_id];
LBitField64 left = s_node_constraints_[left_id];
LBitField64 right = s_node_constraints_[right_id];
dim3 const block3(16, 64, 1);
dim3 const grid3(common::DivRoundUp(n_sets_, 16),

View File

@ -13,7 +13,7 @@
#include "param.h"
#include "../common/span.h"
#include "../common/bitfield.cuh"
#include "../common/bitfield.h"
#include "../common/device_helpers.cuh"
namespace xgboost {
@ -115,25 +115,25 @@ struct FeatureInteractionConstraint {
// Allowed features attached to each node, have n_nodes bitfields,
// each of size n_features.
std::vector<dh::device_vector<BitField::value_type>> node_constraints_storage_;
std::vector<BitField> node_constraints_;
common::Span<BitField> s_node_constraints_;
std::vector<dh::device_vector<LBitField64::value_type>> node_constraints_storage_;
std::vector<LBitField64> node_constraints_;
common::Span<LBitField64> s_node_constraints_;
// buffer storing return feature list from Query, of size n_features.
dh::device_vector<int32_t> result_buffer_;
common::Span<int32_t> s_result_buffer_;
// Temp buffers, one bit for each possible feature.
dh::device_vector<BitField::value_type> output_buffer_bits_storage_;
BitField output_buffer_bits_;
dh::device_vector<BitField::value_type> input_buffer_bits_storage_;
BitField input_buffer_bits_;
dh::device_vector<LBitField64::value_type> output_buffer_bits_storage_;
LBitField64 output_buffer_bits_;
dh::device_vector<LBitField64::value_type> input_buffer_bits_storage_;
LBitField64 input_buffer_bits_;
/*
* Combined features from all interaction sets that one feature belongs to.
* For an input with [[0, 1], [1, 2]], the feature 1 belongs to sets {0, 1}
*/
dh::device_vector<BitField::value_type> d_feature_buffer_storage_;
BitField feature_buffer_; // of Size n features.
dh::device_vector<LBitField64::value_type> d_feature_buffer_storage_;
LBitField64 feature_buffer_; // of Size n features.
// Clear out all temp buffers except for `feature_buffer_', which is
// handled in `Split'.

View File

@ -0,0 +1,92 @@
/*!
* Copyright 2019 XGBoost contributors
*/
#include <gtest/gtest.h>
#include "../../../src/common/bitfield.h"
namespace xgboost {
// Check() must report exactly one set bit for a hand-written storage pattern,
// for both bit field flavours.
TEST(BitField, Check) {
  {
    // LBitField64: value 2 in storage element 2 is expected to be bit 190.
    using ValueT = LBitField64::value_type;
    using IndexT = common::Span<ValueT>::index_type;

    std::vector<ValueT> words(4, 0);
    words[2] = 2;
    common::Span<ValueT> view {words.data(), static_cast<IndexT>(words.size())};
    auto field = LBitField64(view);

    size_t const expected_bit = 190;
    for (size_t pos = expected_bit + 1; pos < field.Size(); ++pos) {
      ASSERT_FALSE(field.Check(pos));
    }
    ASSERT_TRUE(field.Check(expected_bit));
    for (size_t pos = 0; pos < expected_bit; ++pos) {
      ASSERT_FALSE(field.Check(pos));
    }
  }

  {
    // RBitField8: value 1 << 3 in storage element 2 is expected to be bit 19.
    using ValueT = RBitField8::value_type;
    using IndexT = common::Span<ValueT>::index_type;

    std::vector<ValueT> bytes(4, 0);
    bytes[2] = 1 << 3;
    common::Span<ValueT> view {bytes.data(), static_cast<IndexT>(bytes.size())};
    auto field = RBitField8(view);

    size_t const expected_bit = 19;
    for (size_t pos = 0; pos < expected_bit; ++pos) {
      ASSERT_FALSE(field.Check(pos));
    }
    ASSERT_TRUE(field.Check(expected_bit));
    for (size_t pos = expected_bit + 1; pos < field.Size(); ++pos) {
      ASSERT_FALSE(field.Check(pos));
    }
  }
}
template <typename BitFieldT, typename VT = typename BitFieldT::value_type>
void TestBitFieldSet(typename BitFieldT::value_type res, size_t index, size_t true_bit) {
using IndexT = typename common::Span<VT>::index_type;
std::vector<VT> storage(4, 0);
auto bits = BitFieldT({storage.data(), static_cast<IndexT>(storage.size())});
bits.Set(true_bit);
for (size_t i = 0; i < true_bit; ++i) {
ASSERT_FALSE(bits.Check(i));
}
ASSERT_TRUE(bits.Check(true_bit));
for (size_t i = true_bit + 1; i < storage.size() * BitFieldT::kValueSize; ++i) {
ASSERT_FALSE(bits.Check(i));
}
ASSERT_EQ(storage[index], res);
}
// Set() on both bit field flavours, using the same bit / storage pairs as the
// Check test.
TEST(BitField, Set) {
  // Bit 190 is expected to appear as value 2 in storage element 2.
  TestBitFieldSet<LBitField64>(2, 2, 190);
  // Bit 19 is expected to appear as value 1 << 3 in storage element 2.
  TestBitFieldSet<RBitField8>(1 << 3, 2, 19);
}
template <typename BitFieldT, typename VT = typename BitFieldT::value_type>
void TestBitFieldClear(size_t clear_bit) {
using IndexT = typename common::Span<VT>::index_type;
std::vector<VT> storage(4, 0);
auto bits = BitFieldT({storage.data(), static_cast<IndexT>(storage.size())});
bits.Set(clear_bit);
bits.Clear(clear_bit);
ASSERT_FALSE(bits.Check(clear_bit));
}
// Clear() on both bit field flavours.
TEST(BitField, Clear) {
  TestBitFieldClear<LBitField64>(190);
  TestBitFieldClear<RBitField8>(19);
}
} // namespace xgboost

View File

@ -5,56 +5,55 @@
#include <thrust/copy.h>
#include <thrust/device_vector.h>
#include <vector>
#include "../../../src/common/bitfield.cuh"
#include "../../../src/common/bitfield.h"
#include "../../../src/common/device_helpers.cuh"
namespace xgboost {
__global__ void TestSetKernel(BitField bits) {
__global__ void TestSetKernel(LBitField64 bits) {
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < bits.Size()) {
bits.Set(tid);
}
}
TEST(BitField, Set) {
dh::device_vector<BitField::value_type> storage;
TEST(BitField, GPU_Set) {
dh::device_vector<LBitField64::value_type> storage;
uint32_t constexpr kBits = 128;
storage.resize(128);
auto bits = BitField(dh::ToSpan(storage));
auto bits = LBitField64(dh::ToSpan(storage));
TestSetKernel<<<1, kBits>>>(bits);
std::vector<BitField::value_type> h_storage(storage.size());
std::vector<LBitField64::value_type> h_storage(storage.size());
thrust::copy(storage.begin(), storage.end(), h_storage.begin());
BitField outputs {
common::Span<BitField::value_type>{h_storage.data(),
LBitField64 outputs {
common::Span<LBitField64::value_type>{h_storage.data(),
h_storage.data() + h_storage.size()}};
for (size_t i = 0; i < kBits; ++i) {
ASSERT_TRUE(outputs.Check(i));
}
}
__global__ void TestOrKernel(BitField lhs, BitField rhs) {
__global__ void TestOrKernel(LBitField64 lhs, LBitField64 rhs) {
lhs |= rhs;
}
TEST(BitField, And) {
TEST(BitField, GPU_And) {
uint32_t constexpr kBits = 128;
dh::device_vector<BitField::value_type> lhs_storage(kBits);
dh::device_vector<BitField::value_type> rhs_storage(kBits);
auto lhs = BitField(dh::ToSpan(lhs_storage));
auto rhs = BitField(dh::ToSpan(rhs_storage));
dh::device_vector<LBitField64::value_type> lhs_storage(kBits);
dh::device_vector<LBitField64::value_type> rhs_storage(kBits);
auto lhs = LBitField64(dh::ToSpan(lhs_storage));
auto rhs = LBitField64(dh::ToSpan(rhs_storage));
thrust::fill(lhs_storage.begin(), lhs_storage.end(), 0UL);
thrust::fill(rhs_storage.begin(), rhs_storage.end(), ~static_cast<BitField::value_type>(0UL));
thrust::fill(rhs_storage.begin(), rhs_storage.end(), ~static_cast<LBitField64::value_type>(0UL));
TestOrKernel<<<1, kBits>>>(lhs, rhs);
std::vector<BitField::value_type> h_storage(lhs_storage.size());
std::vector<LBitField64::value_type> h_storage(lhs_storage.size());
thrust::copy(lhs_storage.begin(), lhs_storage.end(), h_storage.begin());
BitField outputs {{h_storage.data(), h_storage.data() + h_storage.size()}};
LBitField64 outputs {{h_storage.data(), h_storage.data() + h_storage.size()}};
for (size_t i = 0; i < kBits; ++i) {
ASSERT_TRUE(outputs.Check(i));
}
}
} // namespace xgboost

View File

@ -0,0 +1,57 @@
/*! Copyright 2019 by Contributors */
#include <gtest/gtest.h>
#include <xgboost/data.h>
#include <xgboost/json.h>
#include <thrust/device_vector.h>
#include "../../../src/common/device_helpers.cuh"
namespace xgboost {
// Feed one device column, described as a JSON array interface, to
// MetaInfo::SetInfo for "label", "weight" and "base_margin", and verify that
// each destination holds the column's values.
TEST(MetaInfo, FromInterface) {
cudaSetDevice(0);
constexpr size_t kRows = 16;
// Device column with values 0, 2, 4, ...
thrust::device_vector<float> d_data(kRows);
for (size_t i = 0; i < d_data.size(); ++i) {
d_data[i] = i * 2.0;
}
// Describe the column: 1-dimensional, stride 4, version 1, typestr "<f4".
Json column { Object() };
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
column["shape"] = Array(j_shape);
column["strides"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(4)))});
column["version"] = Integer(static_cast<Integer::Int>(1));
column["typestr"] = String("<f4");
// "data" is the pair [device address as integer, false].
auto p_d_data = dh::Raw(d_data);
std::vector<Json> j_data {
Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
Json(Boolean(false))};
column["data"] = j_data;
// Serialize the description; the same string is reused for every key below.
std::stringstream ss;
Json::Dump(column, &ss);
std::string str = ss.str();
MetaInfo info;
info.SetInfo("label", str.c_str());
auto const& h_label = info.labels_.HostVector();
for (size_t i = 0; i < d_data.size(); ++i) {
ASSERT_EQ(h_label[i], d_data[i]);
}
info.SetInfo("weight", str.c_str());
auto const& h_weight = info.weights_.HostVector();
for (size_t i = 0; i < d_data.size(); ++i) {
ASSERT_EQ(h_weight[i], d_data[i]);
}
info.SetInfo("base_margin", str.c_str());
auto const& h_base_margin = info.base_margin_.HostVector();
for (size_t i = 0; i < d_data.size(); ++i) {
ASSERT_EQ(h_base_margin[i], d_data[i]);
}
}
} // namespace xgboost

View File

@ -1,10 +1,15 @@
// Copyright by Contributors
#include <xgboost/data.h>
#include <gtest/gtest.h>
#include <dmlc/filesystem.h>
#include <xgboost/data.h>
#include <xgboost/json.h>
#include "../../../src/data/simple_csr_source.h"
#include "../helpers.h"
namespace xgboost {
TEST(SimpleCSRSource, SaveLoadBinary) {
dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm";
@ -33,3 +38,4 @@ TEST(SimpleCSRSource, SaveLoadBinary) {
delete dmat;
delete dmat_read;
}
} // namespace xgboost

View File

@ -0,0 +1,216 @@
// Copyright (c) 2019 by Contributors
#include <gtest/gtest.h>
#include <xgboost/data.h>
#include <xgboost/json.h>
#include <thrust/device_vector.h>
#include <memory>
#include "../../../src/common/bitfield.h"
#include "../../../src/common/device_helpers.cuh"
#include "../../../src/data/simple_csr_source.h"
namespace xgboost {
// Load a single dense device column (no validity mask) through
// SimpleCSRSource::CopyFrom and verify the resulting CSR page: one entry per
// row, all belonging to feature 0.
TEST(SimpleCSRSource, FromColumnarDense) {
constexpr size_t kRows = 16;
// Describe the column: 1-dimensional with kRows elements and stride 4.
Json column { Object() };
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
column["shape"] = Array(j_shape);
column["strides"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(4)))});
// Device column with values 0, 2, 4, ...
thrust::device_vector<float> d_data(kRows);
for (size_t i = 0; i < d_data.size(); ++i) {
d_data[i] = i * 2.0;
}
// "data" is the pair [device address as integer, false].
auto p_d_data = dh::Raw(d_data);
std::vector<Json> j_data {
Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
Json(Boolean(false))};
column["data"] = j_data;
column["version"] = Integer(static_cast<Integer::Int>(1));
column["typestr"] = String("<f4");
// CopyFrom expects a JSON array of columns, so wrap the single column.
Json column_arr {Array{std::vector<Json>{column}}};
std::stringstream ss;
Json::Dump(column_arr, &ss);
std::string str = ss.str();
std::unique_ptr<data::SimpleCSRSource> source (new data::SimpleCSRSource());
source->CopyFrom(str.c_str());
auto const& data = source->page_.data.HostVector();
auto const& offset = source->page_.offset.HostVector();
// Entry i must carry the i-th column value.
for (size_t i = 0; i < kRows; ++i) {
auto e = data[i];
ASSERT_NEAR(e.fvalue, i * 2.0, kRtEps);
ASSERT_EQ(e.index, 0); // feature 0
}
// One entry per row: the offsets are 0, 1, ..., kRows.
ASSERT_EQ(offset.back(), 16);
for (size_t i = 0; i < kRows + 1; ++i) {
ASSERT_EQ(offset[i], i);
}
}
// Load kCols device columns that all share the same validity mask, so that the
// masked rows end up with no entries at all, then verify the CSR page.
TEST(SimpleCSRSource, FromColumnarWithEmptyRows) {
// In this test we construct a data storage similar to cudf
constexpr size_t kRows = 102;
constexpr size_t kCols = 24;
// Number of cleared bits in every mask; reported as `null_count' below.
constexpr size_t kMissingRows = 3;
std::vector<Json> v_columns (kCols);
std::vector<dh::device_vector<float>> columns_data(kCols);
std::vector<dh::device_vector<unsigned char>> column_bitfields(kCols);
unsigned char constexpr kUCOne = 1;
for (size_t i = 0; i < kCols; ++i) {
auto& col = v_columns[i];
col = Object();
// Column values are 0, 1, ..., kRows - 1, i.e. each value equals its row index.
auto& data = columns_data[i];
data.resize(kRows);
thrust::sequence(data.begin(), data.end(), 0);
dh::safe_cuda(cudaDeviceSynchronize());
dh::safe_cuda(cudaGetLastError());
ASSERT_EQ(data.size(), kRows);
// "data" is the pair [device address as integer, false].
auto p_d_data = raw_pointer_cast(data.data());
std::vector<Json> j_data {
Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
Json(Boolean(false))};
col["data"] = j_data;
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
col["shape"] = Array(j_shape);
col["version"] = Integer(static_cast<Integer::Int>(1));
col["typestr"] = String("<f4");
// Construct the mask object.
col["mask"] = Object();
auto& j_mask = col["mask"];
auto& mask_storage = column_bitfields[i];
mask_storage.resize(16); // 16 bytes
// A cleared bit marks a missing row; bit b of byte k describes row 8 * k + b.
mask_storage[0] = ~(kUCOne << 2); // byte 0, bit 2: the 3rd row (index 2) is missing
mask_storage[1] = ~(kUCOne << 3); // byte 1, bit 3: the 12th row (index 11) is missing
size_t last_ind = 12;
mask_storage[last_ind] = ~(kUCOne << 5); // byte 12, bit 5: row index 101 is missing
// Despite the name, this set holds the indices of the mask *bytes* that
// contain a cleared bit, not row indices.
std::set<size_t> missing_row_index {0, 1, last_ind};
// NOTE(review): this loop variable shadows the column index `i' of the
// enclosing loop; the outer `i' is not used again after this point.
for (size_t i = 0; i < mask_storage.size(); ++i) {
if (missing_row_index.find(i) == missing_row_index.cend()) {
// every bit of the remaining bytes is set: those rows are valid
mask_storage[i] = ~0;
}
}
j_mask["data"] = std::vector<Json>{
Json(Integer(reinterpret_cast<Integer::Int>(mask_storage.data().get()))),
Json(Boolean(false))};
j_mask["shape"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(16)))});
j_mask["typestr"] = String("|i1");
j_mask["null_count"] = Json(Integer(static_cast<Integer::Int>(kMissingRows)));
}
Json column_arr {Array(v_columns)};
std::stringstream ss;
Json::Dump(column_arr, &ss);
std::string str = ss.str();
std::unique_ptr<data::SimpleCSRSource> source (new data::SimpleCSRSource());
source->CopyFrom(str.c_str());
auto const& data = source->page_.data.HostVector();
auto const& offset = source->page_.offset.HostVector();
ASSERT_EQ(offset.size(), kRows + 1);
// Walk every row: its entries must be ordered by feature index and must all
// carry the row index as value. Rows masked out in every column contribute
// no entries, so the inner loop does not run for them.
for (size_t i = 1; i < offset.size(); ++i) {
for (size_t j = offset[i-1]; j < offset[i]; ++j) {
ASSERT_EQ(data[j].index, j % kCols);
ASSERT_NEAR(data[j].fvalue, i - 1, kRtEps);
}
}
}
// Load two device columns whose validity masks clear a different row each, and
// spot-check the CSR entries around the two missing values.
TEST(SimpleCSRSource, FromColumnarSparse) {
constexpr size_t kRows = 32;
constexpr size_t kCols = 2;
unsigned char constexpr kUCOne = 1;
std::vector<dh::device_vector<float>> columns_data(kCols);
std::vector<dh::device_vector<unsigned char>> column_bitfields(kCols);
// A cleared bit marks a missing row; bit b of byte k describes row 8 * k + b.
{
// column 0
auto& mask = column_bitfields[0];
mask.resize(8);
for (size_t j = 0; j < mask.size(); ++j) {
mask[j] = ~0;
}
// byte 0, bit 2: row index 2 is missing in column 0
mask[0] = ~(kUCOne << 2);
}
{
// column 1
auto& mask = column_bitfields[1];
mask.resize(8);
for (size_t j = 0; j < mask.size(); ++j) {
mask[j] = ~0;
}
// byte 2, bit 3: row index 19 is missing in column 1
mask[2] = ~(kUCOne << 3);
}
// Column values are 0, 1, ..., kRows - 1, i.e. each value equals its row index.
for (size_t c = 0; c < kCols; ++c) {
columns_data[c].resize(kRows);
thrust::sequence(columns_data[c].begin(), columns_data[c].end(), 0);
}
std::vector<Json> j_columns(kCols);
for (size_t c = 0; c < kCols; ++c) {
auto& column = j_columns[c];
column = Object();
column["version"] = Integer(static_cast<Integer::Int>(1));
column["typestr"] = String("<f4");
// "data" is the pair [device address as integer, false].
auto p_d_data = raw_pointer_cast(columns_data[c].data());
std::vector<Json> j_data {
Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
Json(Boolean(false))};
column["data"] = j_data;
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
column["shape"] = Array(j_shape);
// NOTE(review): "version" and "typestr" were already assigned the same
// values at the top of this loop body; these two assignments are redundant.
column["version"] = Integer(static_cast<Integer::Int>(1));
column["typestr"] = String("<f4");
// Mask object: 8 bytes, exactly one cleared bit per column.
column["mask"] = Object();
auto& j_mask = column["mask"];
j_mask["data"] = std::vector<Json>{
Json(Integer(reinterpret_cast<Integer::Int>(column_bitfields[c].data().get()))),
Json(Boolean(false))};
j_mask["shape"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(8)))});
j_mask["typestr"] = String("|i1");
j_mask["null_count"] = Json(Integer(static_cast<Integer::Int>(1)));
}
Json column_arr {Array(j_columns)};
std::stringstream ss;
Json::Dump(column_arr, &ss);
std::string str = ss.str();
std::unique_ptr<data::SimpleCSRSource> source (new data::SimpleCSRSource());
source->CopyFrom(str.c_str());
auto const& data = source->page_.data.HostVector();
auto const& offset = source->page_.offset.HostVector();
ASSERT_EQ(offset.size(), kRows + 1);
// Rows 0 and 1 contribute two entries each, so entry 4 is the first entry of
// row 2; column 0 is missing there, leaving only feature 1 with value 2.
ASSERT_EQ(data[4].index, 1);
ASSERT_EQ(data[4].fvalue, 2);
// Rows 0..18 contribute 2 * 19 - 1 = 37 entries, so entry 37 is the first
// entry of row 19: feature 0 with value 19 (column 1 is missing in that row).
ASSERT_EQ(data[37].index, 0);
ASSERT_EQ(data[37].fvalue, 19);
}
} // namespace xgboost

View File

@ -16,7 +16,7 @@ namespace xgboost {
namespace {
struct FConstraintWrapper : public FeatureInteractionConstraint {
common::Span<BitField> GetNodeConstraints() {
common::Span<LBitField64> GetNodeConstraints() {
return FeatureInteractionConstraint::s_node_constraints_;
}
FConstraintWrapper(tree::TrainParam param, int32_t n_features) :
@ -44,13 +44,13 @@ tree::TrainParam GetParameter() {
return param;
}
void CompareBitField(BitField d_field, std::set<uint32_t> positions) {
std::vector<BitField::value_type> h_field_storage(d_field.bits_.size());
thrust::copy(thrust::device_ptr<BitField::value_type>(d_field.bits_.data()),
thrust::device_ptr<BitField::value_type>(
void CompareBitField(LBitField64 d_field, std::set<uint32_t> positions) {
std::vector<LBitField64::value_type> h_field_storage(d_field.bits_.size());
thrust::copy(thrust::device_ptr<LBitField64::value_type>(d_field.bits_.data()),
thrust::device_ptr<LBitField64::value_type>(
d_field.bits_.data() + d_field.bits_.size()),
h_field_storage.data());
BitField h_field;
LBitField64 h_field;
h_field.bits_ = {h_field_storage.data(), h_field_storage.data() + h_field_storage.size()};
for (size_t i = 0; i < h_field.Size(); ++i) {
@ -71,14 +71,14 @@ TEST(FeatureInteractionConstraint, Init) {
tree::TrainParam param = GetParameter();
FConstraintWrapper constraints(param, kFeatures);
ASSERT_EQ(constraints.Features(), kFeatures);
common::Span<BitField> s_nodes_constraints = constraints.GetNodeConstraints();
for (BitField const& d_node : s_nodes_constraints) {
std::vector<BitField::value_type> h_node_storage(d_node.bits_.size());
thrust::copy(thrust::device_ptr<BitField::value_type>(d_node.bits_.data()),
thrust::device_ptr<BitField::value_type>(
common::Span<LBitField64> s_nodes_constraints = constraints.GetNodeConstraints();
for (LBitField64 const& d_node : s_nodes_constraints) {
std::vector<LBitField64::value_type> h_node_storage(d_node.bits_.size());
thrust::copy(thrust::device_ptr<LBitField64::value_type>(d_node.bits_.data()),
thrust::device_ptr<LBitField64::value_type>(
d_node.bits_.data() + d_node.bits_.size()),
h_node_storage.data());
BitField h_node;
LBitField64 h_node;
h_node.bits_ = {h_node_storage.data(), h_node_storage.data() + h_node_storage.size()};
// no feature is attached to node.
for (size_t i = 0; i < h_node.Size(); ++i) {
@ -108,7 +108,7 @@ TEST(FeatureInteractionConstraint, Init) {
}
{
// Test having more than 1 BitField::value_type
// Test having more than 1 LBitField64::value_type
int32_t constexpr kFeatures = 129;
tree::TrainParam param = GetParameter();
param.interaction_constraints = R"([[0, 1, 3], [3, 5, 128], [127, 128]])";
@ -129,7 +129,7 @@ TEST(FeatureInteractionConstraint, Split) {
FConstraintWrapper constraints(param, kFeatures);
{
BitField d_node[3];
LBitField64 d_node[3];
constraints.Split(0, /*feature_id=*/1, 1, 2);
for (size_t nid = 0; nid < 3; ++nid) {
d_node[nid] = constraints.GetNodeConstraints()[nid];
@ -139,7 +139,7 @@ TEST(FeatureInteractionConstraint, Split) {
}
{
BitField d_node[5];
LBitField64 d_node[5];
constraints.Split(1, /*feature_id=*/0, /*left_id=*/3, /*right_id=*/4);
for (auto nid : {1, 3, 4}) {
d_node[nid] = constraints.GetNodeConstraints()[nid];

View File

@ -0,0 +1,44 @@
import numpy as np
import xgboost as xgb
import sys
import pytest
# Make the shared helper module under tests/python importable as `tm`.
# NOTE(review): the path is relative, so this presumably expects pytest to be
# launched from the repository root -- confirm.
sys.path.append("tests/python")
import testing as tm
# Module-level pytest mark: the skipif arguments come from tm.no_cudf(), so the
# tests in this module are skipped whenever its 'condition' entry is true.
pytestmark = pytest.mark.skipif(**tm.no_cudf())
class TestFromColumnar:
    '''Tests for constructing DMatrix from data structures conforming to the
    Apache Arrow specification.'''

    @pytest.mark.skipif(**tm.no_cudf())
    def test_from_cudf(self):
        '''Test constructing DMatrix from cudf.

        Builds a float32 pandas frame containing NaN entries, converts both
        the features and the label to cudf, and checks that the resulting
        DMatrix reports the same number of rows and columns.
        '''
        # Imported lazily so that collecting this module does not fail when
        # cudf/pandas are missing; the skipif mark then skips the test.
        import cudf
        import pandas as pd

        kRows = 80
        kCols = 2

        na = np.random.randn(kRows, kCols).astype(np.float32)
        # Put one NaN into each column so the input contains missing values.
        # `np.nan` is the canonical spelling; the `np.NAN` alias was removed
        # in NumPy 2.0.
        na[3, 1] = np.nan
        na[5, 0] = np.nan

        pa = pd.DataFrame(na)
        # Column labels are the stringified column positions: '0', '1', ...
        pa.columns = [str(i) for i in range(kCols)]

        np_label = np.random.randn(kRows).astype(np.float32)
        pa_label = pd.DataFrame(np_label)

        cd: cudf.DataFrame = cudf.from_pandas(pa)
        cd_label: cudf.DataFrame = cudf.from_pandas(pa_label)

        dtrain = xgb.DMatrix(cd, label=cd_label)
        assert dtrain.num_col() == kCols
        assert dtrain.num_row() == kRows

View File

@ -1,5 +1,6 @@
# coding: utf-8
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED, DT_INSTALLED, DASK_INSTALLED
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED, DT_INSTALLED
from xgboost.compat import CUDF_INSTALLED, DASK_INSTALLED
def no_sklearn():
@ -31,3 +32,8 @@ def no_matplotlib():
except ImportError:
return {'condition': True,
'reason': reason}
def no_cudf():
    """Return the keyword arguments for a ``pytest.mark.skipif`` cudf guard.

    The returned dict has two entries: ``condition``, which is true when
    ``CUDF_INSTALLED`` is falsy, and ``reason``, the message shown for the
    skipped test.
    """
    cudf_missing = not CUDF_INSTALLED
    return dict(condition=cudf_missing, reason='CUDF is not installed')