Cudf support. (#4745)

* Initial support for cudf integration.

* Add two C APIs for consuming data and metainfo.

* Add CopyFrom for SimpleCSRSource as a generic function to consume the data.

* Add FromDeviceColumnar for consuming device data.

* Add new MetaInfo::SetInfo for consuming label, weight etc.
This commit is contained in:
Jiaming Yuan 2019-08-19 00:51:40 -04:00 committed by Rory Mitchell
parent ab357dd41c
commit 9700776597
26 changed files with 1385 additions and 287 deletions

View File

@ -10,17 +10,17 @@
#include <dmlc/base.h> #include <dmlc/base.h>
#include <dmlc/data.h> #include <dmlc/data.h>
#include <rabit/rabit.h> #include <rabit/rabit.h>
#include <cstring> #include <xgboost/base.h>
#include <memory> #include <memory>
#include <numeric> #include <numeric>
#include <algorithm> #include <algorithm>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "./base.h"
#include "../../src/common/span.h" #include "../../src/common/span.h"
#include "../../src/common/group_data.h" #include "../../src/common/group_data.h"
#include "../../src/common/host_device_vector.h" #include "../../src/common/host_device_vector.h"
namespace xgboost { namespace xgboost {
@ -121,6 +121,12 @@ class MetaInfo {
* \param num Number of elements in the source array. * \param num Number of elements in the source array.
*/ */
void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num); void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num);
/*!
* \brief Set information in the meta info with array interface.
* \param key The key of the information.
* \param interface_str String representation of json format array interface.
*/
void SetInfo(const char* key, std::string const& interface_str);
private: private:
/*! \brief argsort of labels */ /*! \brief argsort of labels */

View File

@ -4,8 +4,6 @@
#ifndef XGBOOST_JSON_H_ #ifndef XGBOOST_JSON_H_
#define XGBOOST_JSON_H_ #define XGBOOST_JSON_H_
#include <dmlc/io.h>
#include <xgboost/logging.h> #include <xgboost/logging.h>
#include <string> #include <string>

View File

@ -1,5 +1,5 @@
/*! /*!
* Copyright (c) 2015 by Contributors * Copyright (c) 2015-2019 by Contributors
* \file logging.h * \file logging.h
* \brief defines console logging options for xgboost. * \brief defines console logging options for xgboost.
* Use to enforce unified print behavior. * Use to enforce unified print behavior.
@ -11,12 +11,13 @@
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/parameter.h> #include <dmlc/parameter.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
#include <xgboost/base.h>
#include <sstream> #include <sstream>
#include <map> #include <map>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "./base.h"
namespace xgboost { namespace xgboost {
@ -154,5 +155,14 @@ using LogCallbackRegistryStore = dmlc::ThreadLocalStore<LogCallbackRegistry>;
::xgboost::ConsoleLogger::LogVerbosity::kIgnore) ::xgboost::ConsoleLogger::LogVerbosity::kIgnore)
// Enable LOG(TRACKER) for print messages to tracker // Enable LOG(TRACKER) for print messages to tracker
#define LOG_TRACKER ::xgboost::TrackerLogger() #define LOG_TRACKER ::xgboost::TrackerLogger()
#if defined(CHECK)
#undef CHECK
#define CHECK(cond) \
if (XGBOOST_EXPECT(!(cond), false)) \
dmlc::LogMessageFatal(__FILE__, __LINE__).stream() \
<< "Check failed: " #cond << ": "
#endif // defined(CHECK)
} // namespace xgboost. } // namespace xgboost.
#endif // XGBOOST_LOGGING_H_ #endif // XGBOOST_LOGGING_H_

View File

@ -124,9 +124,16 @@ except ImportError:
class DataTable(object): class DataTable(object):
""" dummy for datatable.DataTable """ """ dummy for datatable.DataTable """
DT_INSTALLED = False DT_INSTALLED = False
try:
from cudf import DataFrame as CUDF_DataFrame
CUDF_INSTALLED = True
except ImportError:
CUDF_DataFrame = object
CUDF_INSTALLED = False
# sklearn # sklearn
try: try:
from sklearn.base import BaseEstimator from sklearn.base import BaseEstimator

View File

@ -1,26 +1,27 @@
# coding: utf-8 # coding: utf-8
# pylint: disable=too-many-arguments, too-many-branches, invalid-name # pylint: disable=too-many-arguments, too-many-branches, invalid-name
# pylint: disable=too-many-branches, too-many-lines, too-many-locals # pylint: disable=too-many-branches, too-many-lines, too-many-locals
# pylint: disable=too-many-public-methods
"""Core XGBoost Library.""" """Core XGBoost Library."""
from __future__ import absolute_import
import collections import collections
# pylint: disable=no-name-in-module,import-error # pylint: disable=no-name-in-module,import-error
try: from collections.abc import Mapping # Python 3
from collections.abc import Mapping # Python 3
except ImportError:
from collections import Mapping # Python 2
# pylint: enable=no-name-in-module,import-error # pylint: enable=no-name-in-module,import-error
import math
import ctypes import ctypes
import os import os
import re import re
import sys import sys
import warnings import warnings
import json
import numpy as np import numpy as np
import scipy.sparse import scipy.sparse
from .compat import (STRING_TYPES, PY3, DataFrame, MultiIndex, py_str, from .compat import (STRING_TYPES, PY3, DataFrame, MultiIndex, py_str,
PANDAS_INSTALLED, DataTable, os_fspath, os_PathLike) PANDAS_INSTALLED, DataTable,
CUDF_INSTALLED, CUDF_DataFrame,
os_fspath, os_PathLike)
from .libpath import find_lib_path from .libpath import find_lib_path
@ -131,8 +132,10 @@ def _load_lib():
os_error_list = [] os_error_list = []
for lib_path in lib_paths: for lib_path in lib_paths:
try: try:
# needed when the lib is linked with non-system-available dependencies # needed when the lib is linked with non-system-available
os.environ['PATH'] = os.pathsep.join(pathBackup + [os.path.dirname(lib_path)]) # dependencies
os.environ['PATH'] = os.pathsep.join(
pathBackup + [os.path.dirname(lib_path)])
lib = ctypes.cdll.LoadLibrary(lib_path) lib = ctypes.cdll.LoadLibrary(lib_path)
lib_success = True lib_success = True
except OSError as e: except OSError as e:
@ -217,6 +220,51 @@ def c_array(ctype, values):
return (ctype * len(values))(*values) return (ctype * len(values))(*values)
def _use_columnar_initializer(data):
'''Whether should we use columnar format initializer (pass data in as
json string). Currently cudf is the only valid option.'''
if CUDF_INSTALLED and isinstance(data, CUDF_DataFrame):
return True
return False
def _extract_interface_from_cudf(df, is_info):
'''This function should be upstreamed to cudf.'''
if not _use_columnar_initializer(df):
raise ValueError('Only cudf is supported for initializing as json ' +
'columnar format. For other libraries please ' +
'refer to specific API.')
def get_interface(obj):
return obj.mem.__cuda_array_interface__
array_interfaces = []
for col in df.columns:
data = df[col].data
array_interfaces.append(get_interface(data))
validity_masks = []
for col in df.columns:
if df[col].has_null_mask:
mask_interface = get_interface(df[col].nullmask)
mask_interface['null_count'] = df[col].null_count
validity_masks.append(mask_interface)
else:
validity_masks.append(False)
for i in range(len(df.columns)):
col_interface = array_interfaces[i]
mask_interface = validity_masks[i]
if mask_interface is not False:
col_interface['mask'] = mask_interface
if is_info:
array_interfaces = array_interfaces[0]
interfaces = bytes(json.dumps(array_interfaces, indent=2), 'utf-8')
return interfaces
PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int', PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int', 'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
'float16': 'float', 'float32': 'float', 'float64': 'float', 'float16': 'float', 'float32': 'float', 'float64': 'float',
@ -256,15 +304,18 @@ def _maybe_pandas_data(data, feature_names, feature_types):
def _maybe_pandas_label(label): def _maybe_pandas_label(label):
""" Extract internal data from pd.DataFrame for DMatrix label """ """Extract internal data from pd.DataFrame for DMatrix label."""
if PANDAS_INSTALLED and isinstance(label, DataFrame): if PANDAS_INSTALLED and isinstance(label, DataFrame):
if len(label.columns) > 1: if len(label.columns) > 1:
raise ValueError('DataFrame for label cannot have multiple columns') raise ValueError(
'DataFrame for label cannot have multiple columns')
label_dtypes = label.dtypes label_dtypes = label.dtypes
if not all(dtype.name in PANDAS_DTYPE_MAPPER for dtype in label_dtypes): if not all(dtype.name in PANDAS_DTYPE_MAPPER
raise ValueError('DataFrame.dtypes for label must be int, float or bool') for dtype in label_dtypes):
raise ValueError(
'DataFrame.dtypes for label must be int, float or bool')
label = label.values.astype('float') label = label.values.astype('float')
# pd.Series can be passed to xgb as it is # pd.Series can be passed to xgb as it is
@ -318,6 +369,22 @@ def _maybe_dt_array(array):
return array return array
def _check_data(data, missing):
'''The missing value applies only to np.ndarray.'''
is_invalid = (not isinstance(data, np.ndarray)) and (missing is not None)
is_invalid = is_invalid and not math.isnan(missing)
if is_invalid:
raise ValueError(
'missing value only applies to dense input, ' +
'e.g. `numpy.ndarray`.' +
' For a possibly sparse data type: ' + str(type(data)) +
' please remove missing values or set it to nan.' +
' Current missing value is set to: ' + str(missing))
if isinstance(data, list):
warnings.warn('Initializing DMatrix from List is deprecated.',
DeprecationWarning)
class DMatrix(object): class DMatrix(object):
"""Data Matrix used in XGBoost. """Data Matrix used in XGBoost.
@ -336,15 +403,16 @@ class DMatrix(object):
""" """
Parameters Parameters
---------- ----------
data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/dt.Frame data : os.PathLike/string/numpy.array/scipy.sparse/pd.DataFrame/
dt.Frame/cudf.DataFrame
Data source of DMatrix. Data source of DMatrix.
When data is string or os.PathLike type, it represents the path libsvm format When data is string or os.PathLike type, it represents the path libsvm format
txt file, or binary file that xgboost can read from. txt file, or binary file that xgboost can read from.
label : list or numpy 1-D array, optional label : list or numpy 1-D array, optional
Label of the training data. Label of the training data.
missing : float, optional missing : float, optional
Value in the data which needs to be present as a missing value. If Value in the dense input data (e.g. `numpy.ndarray`) which needs
None, defaults to np.nan. to be present as a missing value. If None, defaults to np.nan.
weight : list or numpy 1-D array , optional weight : list or numpy 1-D array , optional
Weight for each instance. Weight for each instance.
@ -375,6 +443,8 @@ class DMatrix(object):
self._feature_types = feature_types self._feature_types = feature_types
return return
_check_data(data, missing)
data, feature_names, feature_types = _maybe_pandas_data(data, data, feature_names, feature_types = _maybe_pandas_data(data,
feature_names, feature_names,
feature_types) feature_types)
@ -382,14 +452,11 @@ class DMatrix(object):
data, feature_names, feature_types = _maybe_dt_data(data, data, feature_names, feature_types = _maybe_dt_data(data,
feature_names, feature_names,
feature_types) feature_types)
label = _maybe_pandas_label(label) label = _maybe_pandas_label(label)
label = _maybe_dt_array(label) label = _maybe_dt_array(label)
weight = _maybe_dt_array(weight) weight = _maybe_dt_array(weight)
if isinstance(data, list):
warnings.warn('Initializing DMatrix from List is deprecated.',
DeprecationWarning)
if isinstance(data, (STRING_TYPES, os_PathLike)): if isinstance(data, (STRING_TYPES, os_PathLike)):
handle = ctypes.c_void_p() handle = ctypes.c_void_p()
_check_call(_LIB.XGDMatrixCreateFromFile(c_str(os_fspath(data)), _check_call(_LIB.XGDMatrixCreateFromFile(c_str(os_fspath(data)),
@ -404,6 +471,8 @@ class DMatrix(object):
self._init_from_npy2d(data, missing, nthread) self._init_from_npy2d(data, missing, nthread)
elif isinstance(data, DataTable): elif isinstance(data, DataTable):
self._init_from_dt(data, nthread) self._init_from_dt(data, nthread)
elif _use_columnar_initializer(data):
self._init_from_columnar(data)
else: else:
try: try:
csr = scipy.sparse.csr_matrix(data) csr = scipy.sparse.csr_matrix(data)
@ -415,11 +484,15 @@ class DMatrix(object):
if label is not None: if label is not None:
if isinstance(label, np.ndarray): if isinstance(label, np.ndarray):
self.set_label_npy2d(label) self.set_label_npy2d(label)
elif _use_columnar_initializer(label):
self.set_interface_info('label', label)
else: else:
self.set_label(label) self.set_label(label)
if weight is not None: if weight is not None:
if isinstance(weight, np.ndarray): if isinstance(weight, np.ndarray):
self.set_weight_npy2d(weight) self.set_weight_npy2d(weight)
elif _use_columnar_initializer(label):
self.set_interface_info('weight', weight)
else: else:
self.set_weight(weight) self.set_weight(weight)
@ -526,8 +599,19 @@ class DMatrix(object):
nthread)) nthread))
self.handle = handle self.handle = handle
def _init_from_columnar(self, df):
'''Initialize DMatrix from columnar memory format.
'''
interfaces = _extract_interface_from_cudf(df, False)
handle = ctypes.c_void_p()
_check_call(
_LIB.XGDMatrixCreateFromArrayInterfaces(interfaces,
ctypes.byref(handle)))
self.handle = handle
def __del__(self): def __del__(self):
if hasattr(self, "handle") and self.handle is not None: if hasattr(self, "handle") and self.handle:
_check_call(_LIB.XGDMatrixFree(self.handle)) _check_call(_LIB.XGDMatrixFree(self.handle))
self.handle = None self.handle = None
@ -593,6 +677,13 @@ class DMatrix(object):
c_data, c_data,
c_bst_ulong(len(data)))) c_bst_ulong(len(data))))
def set_interface_info(self, field, data):
'''Set info type peoperty into DMatrix.'''
interfaces = _extract_interface_from_cudf(data, True)
_check_call(_LIB.XGDMatrixSetInfoFromInterface(self.handle,
c_str(field),
interfaces))
def set_float_info_npy2d(self, field, data): def set_float_info_npy2d(self, field, data):
"""Set float type property into the DMatrix """Set float type property into the DMatrix
for numpy 2d array input for numpy 2d array input
@ -732,7 +823,10 @@ class DMatrix(object):
margin: array like margin: array like
Prediction margin of each datapoint Prediction margin of each datapoint
""" """
self.set_float_info('base_margin', margin) if _use_columnar_initializer(margin):
self.set_interface_info('base_margin', margin)
else:
self.set_float_info('base_margin', margin)
def set_group(self, group): def set_group(self, group):
"""Set group size of DMatrix (used for ranking). """Set group size of DMatrix (used for ranking).
@ -742,9 +836,12 @@ class DMatrix(object):
group : array like group : array like
Group size of each group Group size of each group
""" """
_check_call(_LIB.XGDMatrixSetGroup(self.handle, if _use_columnar_initializer(group):
c_array(ctypes.c_uint, group), self.set_interface_info('group', group)
c_bst_ulong(len(group)))) else:
_check_call(_LIB.XGDMatrixSetGroup(self.handle,
c_array(ctypes.c_uint, group),
c_bst_ulong(len(group))))
def get_label(self): def get_label(self):
"""Get the label of the DMatrix. """Get the label of the DMatrix.
@ -831,7 +928,8 @@ class DMatrix(object):
feature_names : list or None feature_names : list or None
""" """
if self._feature_names is None: if self._feature_names is None:
self._feature_names = ['f{0}'.format(i) for i in range(self.num_col())] self._feature_names = ['f{0}'.format(i)
for i in range(self.num_col())]
return self._feature_names return self._feature_names
@property @property

View File

@ -1,4 +1,4 @@
// Copyright (c) 2014 by Contributors // Copyright (c) 2014-2019 by Contributors
#include <xgboost/data.h> #include <xgboost/data.h>
#include <xgboost/learner.h> #include <xgboost/learner.h>
@ -16,7 +16,7 @@
#include <string> #include <string>
#include <memory> #include <memory>
#include "./c_api_error.h" #include "c_api_error.h"
#include "../data/simple_csr_source.h" #include "../data/simple_csr_source.h"
#include "../common/math.h" #include "../common/math.h"
#include "../common/io.h" #include "../common/io.h"
@ -189,6 +189,16 @@ int XGDMatrixCreateFromDataIter(
API_END(); API_END();
} }
XGB_DLL int XGDMatrixCreateFromArrayInterfaces(
char const* c_json_strs, DMatrixHandle* out) {
API_BEGIN();
std::string json_str {c_json_strs};
std::unique_ptr<data::SimpleCSRSource> source (new data::SimpleCSRSource());
source->CopyFrom(json_str);
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(std::move(source)));
API_END();
}
XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr, XGB_DLL int XGDMatrixCreateFromCSREx(const size_t* indptr,
const unsigned* indices, const unsigned* indices,
const bst_float* data, const bst_float* data,
@ -679,9 +689,9 @@ XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle,
} }
XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle, XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
const char* field, const char* field,
const bst_float* info, const bst_float* info,
xgboost::bst_ulong len) { xgboost::bst_ulong len) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle) static_cast<std::shared_ptr<DMatrix>*>(handle)
@ -689,10 +699,20 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
API_END(); API_END();
} }
XGB_DLL int XGDMatrixSetInfoFromInterface(DMatrixHandle handle,
char const* field,
char const* interface_c_str) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, interface_c_str);
API_END();
}
XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle, XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
const char* field, const char* field,
const unsigned* info, const unsigned* info,
xgboost::bst_ulong len) { xgboost::bst_ulong len) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle) static_cast<std::shared_ptr<DMatrix>*>(handle)
@ -771,7 +791,7 @@ XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,
xgboost::bst_ulong *out) { xgboost::bst_ulong *out) {
API_BEGIN(); API_BEGIN();
CHECK_HANDLE(); CHECK_HANDLE();
*out = static_cast<size_t>( *out = static_cast<xgboost::bst_ulong>(
static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_col_); static_cast<std::shared_ptr<DMatrix>*>(handle)->get()->Info().num_col_);
API_END(); API_END();
} }

View File

@ -1,160 +0,0 @@
#ifndef XGBOOST_COMMON_BITFIELD_CUH_
#define XGBOOST_COMMON_BITFIELD_CUH_
#include <bitset>
#include <string>
#include <iostream>
#include <sstream>
#include <cinttypes>
#include "span.h"
namespace xgboost {
__forceinline__ __device__ unsigned long long AtomicOr(unsigned long long* address,
unsigned long long val) {
unsigned long long int old = *address, assumed; // NOLINT
do {
assumed = old;
old = atomicCAS(address, assumed, val | assumed);
} while (assumed != old);
return old;
}
__forceinline__ __device__ unsigned long long AtomicAnd(unsigned long long* address,
unsigned long long val) {
unsigned long long int old = *address, assumed; // NOLINT
do {
assumed = old;
old = atomicCAS(address, assumed, val & assumed);
} while (assumed != old);
return old;
}
/*!
* \brief A non-owning type with auxiliary methods defined for manipulating bits.
*/
struct BitField {
using value_type = uint64_t;
static value_type constexpr kValueSize = sizeof(value_type) * 8;
static value_type constexpr kOne = 1UL; // force uint64_t
static_assert(kValueSize == 64, "uint64_t should be of 64 bits.");
struct Pos {
value_type int_pos {0};
value_type bit_pos {0};
};
common::Span<value_type> bits_;
public:
BitField() = default;
XGBOOST_DEVICE BitField(common::Span<value_type> bits) : bits_{bits} {}
XGBOOST_DEVICE BitField(BitField const& other) : bits_{other.bits_} {}
static size_t ComputeStorageSize(size_t size) {
auto pos = ToBitPos(size);
if (size < kValueSize) {
return 1;
}
if (pos.bit_pos != 0) {
return pos.int_pos + 2;
} else {
return pos.int_pos + 1;
}
}
XGBOOST_DEVICE static Pos ToBitPos(value_type pos) {
Pos pos_v;
if (pos == 0) {
return pos_v;
}
pos_v.int_pos = pos / kValueSize;
pos_v.bit_pos = pos % kValueSize;
return pos_v;
}
__device__ BitField& operator|=(BitField const& rhs) {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
size_t min_size = min(bits_.size(), rhs.bits_.size());
if (tid < min_size) {
bits_[tid] |= rhs.bits_[tid];
}
return *this;
}
__device__ BitField& operator&=(BitField const& rhs) {
size_t min_size = min(bits_.size(), rhs.bits_.size());
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < min_size) {
bits_[tid] &= rhs.bits_[tid];
}
return *this;
}
XGBOOST_DEVICE size_t Size() const { return kValueSize * bits_.size(); }
__device__ void Set(value_type pos) {
Pos pos_v = ToBitPos(pos);
value_type& value = bits_[pos_v.int_pos];
value_type set_bit = kOne << (kValueSize - pos_v.bit_pos - kOne);
static_assert(sizeof(unsigned long long int) == sizeof(value_type), "");
AtomicOr(reinterpret_cast<unsigned long long*>(&value), set_bit);
}
__device__ void Clear(value_type pos) {
Pos pos_v = ToBitPos(pos);
value_type& value = bits_[pos_v.int_pos];
value_type clear_bit = ~(kOne << (kValueSize - pos_v.bit_pos - kOne));
static_assert(sizeof(unsigned long long int) == sizeof(value_type), "");
AtomicAnd(reinterpret_cast<unsigned long long*>(&value), clear_bit);
}
XGBOOST_DEVICE bool Check(Pos pos_v) const {
value_type value = bits_[pos_v.int_pos];
value_type const test_bit = kOne << (kValueSize - pos_v.bit_pos - kOne);
value_type result = test_bit & value;
return static_cast<bool>(result);
}
XGBOOST_DEVICE bool Check(value_type pos) const {
Pos pos_v = ToBitPos(pos);
return Check(pos_v);
}
friend std::ostream& operator<<(std::ostream& os, BitField field) {
os << "Bits " << "storage size: " << field.bits_.size() << "\n";
for (size_t i = 0; i < field.bits_.size(); ++i) {
std::bitset<BitField::kValueSize> set(field.bits_[i]);
os << set << "\n";
}
return os;
}
};
inline void PrintDeviceBits(std::string name, BitField field) {
std::cout << "Bits: " << name << std::endl;
std::vector<BitField::value_type> h_field_bits(field.bits_.size());
thrust::copy(thrust::device_ptr<BitField::value_type>(field.bits_.data()),
thrust::device_ptr<BitField::value_type>(field.bits_.data() + field.bits_.size()),
h_field_bits.data());
BitField h_field;
h_field.bits_ = {h_field_bits.data(), h_field_bits.data() + h_field_bits.size()};
std::cout << h_field;
}
inline void PrintDeviceStorage(std::string name, common::Span<int32_t> list) {
std::cout << name << std::endl;
std::vector<int32_t> h_list(list.size());
thrust::copy(thrust::device_ptr<int32_t>(list.data()),
thrust::device_ptr<int32_t>(list.data() + list.size()),
h_list.data());
for (auto v : h_list) {
std::cout << v << ", ";
}
std::cout << std::endl;
}
}
#endif // XGBOOST_COMMON_BITFIELD_CUH_

248
src/common/bitfield.h Normal file
View File

@ -0,0 +1,248 @@
/*!
* Copyright 2019 by Contributors
* \file bitfield.h
*/
#ifndef XGBOOST_COMMON_BITFIELD_H_
#define XGBOOST_COMMON_BITFIELD_H_
#include <algorithm>
#include <bitset>
#include <cinttypes>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include "span.h"
namespace xgboost {
#if defined(__CUDACC__)
using BitFieldAtomicType = unsigned long long; // NOLINT
__forceinline__ __device__ BitFieldAtomicType AtomicOr(BitFieldAtomicType* address,
BitFieldAtomicType val) {
BitFieldAtomicType old = *address, assumed; // NOLINT
do {
assumed = old;
old = atomicCAS(address, assumed, val | assumed);
} while (assumed != old);
return old;
}
__forceinline__ __device__ BitFieldAtomicType AtomicAnd(BitFieldAtomicType* address,
BitFieldAtomicType val) {
BitFieldAtomicType old = *address, assumed; // NOLINT
do {
assumed = old;
old = atomicCAS(address, assumed, val & assumed);
} while (assumed != old);
return old;
}
#endif // defined(__CUDACC__)
/*!
* \brief A non-owning type with auxiliary methods defined for manipulating bits.
*
* \tparam Direction Whether the bits start from left or from right.
*/
template <typename VT, typename Direction>
struct BitFieldContainer {
using value_type = VT;
using pointer = value_type*;
static value_type constexpr kValueSize = sizeof(value_type) * 8;
static value_type constexpr kOne = 1; // force correct type.
struct Pos {
value_type int_pos {0};
value_type bit_pos {0};
};
common::Span<value_type> bits_;
static_assert(!std::is_signed<VT>::value, "Must use unsiged type as underlying storage.");
XGBOOST_DEVICE static Pos ToBitPos(value_type pos) {
Pos pos_v;
if (pos == 0) {
return pos_v;
}
pos_v.int_pos = pos / kValueSize;
pos_v.bit_pos = pos % kValueSize;
return pos_v;
}
public:
BitFieldContainer() = default;
XGBOOST_DEVICE BitFieldContainer(common::Span<value_type> bits) : bits_{bits} {}
XGBOOST_DEVICE BitFieldContainer(BitFieldContainer const& other) : bits_{other.bits_} {}
static size_t ComputeStorageSize(size_t size) {
auto pos = ToBitPos(size);
if (size < kValueSize) {
return 1;
}
if (pos.bit_pos != 0) {
return pos.int_pos + 2;
} else {
return pos.int_pos + 1;
}
}
#if defined(__CUDA_ARCH__)
__device__ BitFieldContainer& operator|=(BitFieldContainer const& rhs) {
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
size_t min_size = min(bits_.size(), rhs.bits_.size());
if (tid < min_size) {
bits_[tid] |= rhs.bits_[tid];
}
return *this;
}
#else
BitFieldContainer& operator|=(BitFieldContainer const& rhs) {
size_t min_size = std::min(bits_.size(), rhs.bits_.size());
for (size_t i = 0; i < min_size; ++i) {
bits_[i] |= rhs.bits_[i];
}
return *this;
}
#endif // #if defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__)
__device__ BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
size_t min_size = min(bits_.size(), rhs.bits_.size());
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < min_size) {
bits_[tid] &= rhs.bits_[tid];
}
return *this;
}
#else
BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
size_t min_size = std::min(bits_.size(), rhs.bits_.size());
for (size_t i = 0; i < min_size; ++i) {
bits_[i] &= rhs.bits_[i];
}
return *this;
}
#endif // defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__)
__device__ void Set(value_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type set_bit = kOne << pos_v.bit_pos;
static_assert(sizeof(BitFieldAtomicType) == sizeof(value_type), "");
AtomicOr(reinterpret_cast<BitFieldAtomicType*>(&value), set_bit);
}
__device__ void Clear(value_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type clear_bit = ~(kOne << pos_v.bit_pos);
static_assert(sizeof(BitFieldAtomicType) == sizeof(value_type), "");
AtomicAnd(reinterpret_cast<BitFieldAtomicType*>(&value), clear_bit);
}
#else
void Set(value_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type set_bit = kOne << pos_v.bit_pos;
value |= set_bit;
}
void Clear(value_type pos) {
Pos pos_v = Direction::Shift(ToBitPos(pos));
value_type& value = bits_[pos_v.int_pos];
value_type clear_bit = ~(kOne << pos_v.bit_pos);
value &= clear_bit;
}
#endif // defined(__CUDA_ARCH__)
XGBOOST_DEVICE bool Check(Pos pos_v) const {
pos_v = Direction::Shift(pos_v);
value_type const value = bits_[pos_v.int_pos];
value_type const test_bit = kOne << pos_v.bit_pos;
value_type result = test_bit & value;
return static_cast<bool>(result);
}
XGBOOST_DEVICE bool Check(value_type pos) const {
Pos pos_v = ToBitPos(pos);
return Check(pos_v);
}
XGBOOST_DEVICE size_t Size() const { return kValueSize * bits_.size(); }
XGBOOST_DEVICE pointer Data() const { return bits_.data(); }
friend std::ostream& operator<<(std::ostream& os, BitFieldContainer<VT, Direction> field) {
os << "Bits " << "storage size: " << field.bits_.size() << "\n";
for (typename common::Span<value_type>::index_type i = 0; i < field.bits_.size(); ++i) {
std::bitset<BitFieldContainer<VT, Direction>::kValueSize> bset(field.bits_[i]);
os << bset << "\n";
}
return os;
}
};
// Bits start from left most bits (most significant bit).
template <typename VT>
struct LBitsPolicy : public BitFieldContainer<VT, LBitsPolicy<VT>> {
using Container = BitFieldContainer<VT, LBitsPolicy<VT>>;
using Pos = typename Container::Pos;
using value_type = typename Container::value_type;
XGBOOST_DEVICE static Pos Shift(Pos pos) {
pos.bit_pos = Container::kValueSize - pos.bit_pos - Container::kOne;
return pos;
}
};
// Bits start from right most bit (least significant bit) of each entry, but integer index
// is from left to right.
template <typename VT>
struct RBitsPolicy : public BitFieldContainer<VT, RBitsPolicy<VT>> {
using Container = BitFieldContainer<VT, RBitsPolicy<VT>>;
using Pos = typename Container::Pos;
using value_type = typename Container::value_type;
XGBOOST_DEVICE static Pos Shift(Pos pos) {
return pos;
}
};
// Format: <Direction>BitField<size of underlying type>, underlying type must be unsigned.
using LBitField64 = BitFieldContainer<uint64_t, LBitsPolicy<uint64_t>>;
using RBitField8 = BitFieldContainer<unsigned char, RBitsPolicy<unsigned char>>;
#if defined(__CUDACC__)
template <typename V, typename D>
inline void PrintDeviceBits(std::string name, BitFieldContainer<V, D> field) {
std::cout << "Bits: " << name << std::endl;
std::vector<typename BitFieldContainer<V, D>::value_type> h_field_bits(field.bits_.size());
thrust::copy(thrust::device_ptr<typename BitFieldContainer<V, D>::value_type>(field.bits_.data()),
thrust::device_ptr<typename BitFieldContainer<V, D>::value_type>(
field.bits_.data() + field.bits_.size()),
h_field_bits.data());
BitFieldContainer<V, D> h_field;
h_field.bits_ = {h_field_bits.data(), h_field_bits.data() + h_field_bits.size()};
std::cout << h_field;
}
inline void PrintDeviceStorage(std::string name, common::Span<int32_t> list) {
std::cout << name << std::endl;
std::vector<int32_t> h_list(list.size());
thrust::copy(thrust::device_ptr<int32_t>(list.data()),
thrust::device_ptr<int32_t>(list.data() + list.size()),
h_list.data());
for (auto v : h_list) {
std::cout << v << ", ";
}
std::cout << std::endl;
}
#endif // defined(__CUDACC__)
} // namespace xgboost
#endif // XGBOOST_COMMON_BITFIELD_H_

View File

@ -51,12 +51,18 @@ inline ncclResult_t ThrowOnNcclError(ncclResult_t code, const char *file,
} }
#endif #endif
inline void CudaCheckPointerDevice(void* ptr) { inline int32_t CudaGetPointerDevice(void* ptr) {
int32_t device = -1;
cudaPointerAttributes attr; cudaPointerAttributes attr;
dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr)); dh::safe_cuda(cudaPointerGetAttributes(&attr, ptr));
int ptr_device = attr.device; device = attr.device;
return device;
}
inline void CudaCheckPointerDevice(void* ptr) {
auto ptr_device = CudaGetPointerDevice(ptr);
int cur_device = -1; int cur_device = -1;
cudaGetDevice(&cur_device); dh::safe_cuda(cudaGetDevice(&cur_device));
CHECK_EQ(ptr_device, cur_device) << "pointer device: " << ptr_device CHECK_EQ(ptr_device, cur_device) << "pointer device: " << ptr_device
<< "current device: " << cur_device; << "current device: " << cur_device;
} }

View File

@ -365,8 +365,9 @@ struct HostDeviceVectorImpl {
void Shard(const GPUDistribution& distribution) { void Shard(const GPUDistribution& distribution) {
if (distribution_ == distribution) { return; } if (distribution_ == distribution) { return; }
CHECK(distribution_.IsEmpty()) CHECK(distribution_.IsEmpty())
<< "This: " << distribution_.Devices().Size() << ", " << "Data resides on different GPUs: " << "ID: "
<< "Others: " << distribution.Devices().Size(); << *(distribution_.Devices().begin()) << " and ID: "
<< *(distribution.Devices().begin());
distribution_ = distribution; distribution_ = distribution;
InitShards(); InitShards();
} }
@ -412,6 +413,7 @@ struct HostDeviceVectorImpl {
perm_h_.Grant(access); perm_h_.Grant(access);
return; return;
} }
std::lock_guard<std::mutex> lock(mutex_);
if (data_h_.size() != size_d_) { data_h_.resize(size_d_); } if (data_h_.size() != size_d_) { data_h_.resize(size_d_); }
dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) { dh::ExecuteIndexShards(&shards_, [&](int idx, DeviceShard& shard) {
shard.LazySyncHost(access); shard.LazySyncHost(access);

82
src/data/columnar.h Normal file
View File

@ -0,0 +1,82 @@
/*!
* Copyright 2019 by Contributors
* \file columnar.h
* \brief Basic structure holding a reference to arrow columnar data format.
*/
#ifndef XGBOOST_DATA_COLUMNAR_H_
#define XGBOOST_DATA_COLUMNAR_H_
#include <cinttypes>
#include <map>
#include <string>
#include "xgboost/data.h"
#include "xgboost/json.h"
#include "../common/span.h"
#include "../common/bitfield.h"
namespace xgboost {
struct Columnar {
using mask_type = unsigned char;
using index_type = int32_t;
common::Span<float> data;
RBitField8 valid;
int32_t size;
int32_t null_count;
};
// Common errors in parsing columnar format.
struct ColumnarErrors {
static char const* Contigious() {
return "Memory should be contigious.";
}
static char const* TypestrFormat() {
return "`typestr` should be of format <endian><type><size>.";
}
// Not supported in Apache Arrow.
static char const* BigEndian() {
return "Big endian is not supported.";
}
static char const* Dimension(int32_t d) {
static std::string str;
str.clear();
str += "Only ";
str += std::to_string(d);
str += " dimensional array is valid.";
return str.c_str();
}
static char const* Version() {
return "Only version 1 of __cuda_array_interface__ is being supported.";
}
static char const* toFloat() {
return "Please convert the input into float32 first.";
}
static char const* toUInt() {
return "Please convert the Group into unsigned 32 bit integers first.";
}
static char const* ofType(std::string type) {
static std::string str;
str.clear();
str += " should be of ";
str += type;
str += " type.";
return str.c_str();
}
};
template <typename PtrType>
PtrType GetPtrFromArrayData(std::map<std::string, Json> const& obj) {
if (obj.find("data") == obj.cend()) {
LOG(FATAL) << "Empty data passed in.";
}
auto p_data = reinterpret_cast<PtrType>(static_cast<size_t>(
get<Integer const>(
get<Array const>(
obj.at("data"))
.at(0))));
return p_data;
}
} // namespace xgboost
#endif // XGBOOST_DATA_COLUMNAR_H_

View File

@ -1,11 +1,12 @@
/*! /*!
* Copyright 2015 by Contributors * Copyright 2015-2019 by Contributors
* \file data.cc * \file data.cc
*/ */
#include <xgboost/data.h> #include <xgboost/data.h>
#include <xgboost/logging.h> #include <xgboost/logging.h>
#include <dmlc/registry.h> #include <dmlc/registry.h>
#include <cstring> #include <cstring>
#include "./sparse_page_writer.h" #include "./sparse_page_writer.h"
#include "./simple_dmatrix.h" #include "./simple_dmatrix.h"
#include "./simple_csr_source.h" #include "./simple_csr_source.h"
@ -110,7 +111,6 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname,
default: LOG(FATAL) << "Unknown data type" << dtype; \ default: LOG(FATAL) << "Unknown data type" << dtype; \
} \ } \
void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) { void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
if (!std::strcmp(key, "root_index")) { if (!std::strcmp(key, "root_index")) {
root_index_.resize(num); root_index_.resize(num);
@ -139,9 +139,17 @@ void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t
for (size_t i = 1; i < group_ptr_.size(); ++i) { for (size_t i = 1; i < group_ptr_.size(); ++i) {
group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i]; group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i];
} }
} else {
LOG(FATAL) << "Unknown metainfo: " << key;
} }
} }
#if !defined(XGBOOST_USE_CUDA)
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
LOG(FATAL) << "XGBoost version is not compiled with GPU support";
}
#endif // !defined(XGBOOST_USE_CUDA)
DMatrix* DMatrix::Load(const std::string& uri, DMatrix* DMatrix::Load(const std::string& uri,
bool silent, bool silent,
bool load_row_split, bool load_row_split,

86
src/data/data.cu Normal file
View File

@ -0,0 +1,86 @@
/*!
* Copyright 2019 by XGBoost Contributors
*
* \file data.cu
*/
#include "xgboost/data.h"
#include "xgboost/logging.h"
#include "xgboost/json.h"
#include "columnar.h"
#include "../common/device_helpers.cuh"
namespace xgboost {
void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
Json j_arr = Json::Load({interface_str.c_str(), interface_str.size()});
auto const& j_arr_obj = get<Object>(j_arr);
std::string key {c_key};
auto version = get<Integer const>(j_arr_obj.at("version"));
CHECK_EQ(version, 1) << ColumnarErrors::Version();
if (j_arr_obj.find("mask") != j_arr_obj.cend()) {
LOG(FATAL) << "Meta info " << key << " should be dense, found validity mask";
}
auto typestr = get<String const>(j_arr_obj.at("typestr"));
CHECK_EQ(typestr.size(), 3) << ColumnarErrors::TypestrFormat();
CHECK_NE(typestr.front(), '>') << ColumnarErrors::BigEndian();
auto j_shape = get<Array const>(j_arr_obj.at("shape"));
CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);
auto length = get<Integer const>(j_shape.at(0));
CHECK_GT(length, 0) << "Label set cannot be empty.";
if (j_arr_obj.find("strides") != j_arr_obj.cend()) {
auto strides = get<Array const>(j_arr_obj.at("strides"));
CHECK_EQ(get<Integer>(strides.at(0)), 4) << ColumnarErrors::Contigious();
}
float* p_data = GetPtrFromArrayData<float*>(j_arr_obj);
cudaPointerAttributes attr;
dh::safe_cuda(cudaPointerGetAttributes(&attr, p_data));
int32_t ptr_device = attr.device;
dh::safe_cuda(cudaSetDevice(ptr_device));
thrust::device_ptr<float> p_src {p_data};
HostDeviceVector<float>* dst;
if (key == "root_index") {
LOG(FATAL) << "root index for columnar data is not supported.";
} else if (key == "label") {
dst = &labels_;
CHECK_EQ(typestr.at(1), 'f') << "Label"
<< ColumnarErrors::ofType("floating point");
CHECK_EQ(typestr.at(2), '4') << ColumnarErrors::toFloat();
} else if (key == "weight") {
dst = &weights_;
CHECK_EQ(typestr.at(1), 'f') << "Weight"
<< ColumnarErrors::ofType("floating point");;
CHECK_EQ(typestr.at(2), '4') << ColumnarErrors::toFloat();
} else if (key == "base_margin") {
dst = &base_margin_;
CHECK_EQ(typestr.at(1), 'f') << "Base Margin"
<< ColumnarErrors::ofType("floating point");
CHECK_EQ(typestr.at(2), '4') << ColumnarErrors::toFloat();
} else if (key == "group") {
CHECK_EQ(typestr.at(1), 'u') << "Group"
<< ColumnarErrors::ofType("unsigned 32 bit integers");
CHECK_EQ(typestr.at(2), '4') << ColumnarErrors::toUInt();
group_ptr_.resize(length + 1);
group_ptr_[0] = 0;
// Ranking is not performed on device.
thrust::copy(p_src, p_src + length, group_ptr_.begin() + 1);
for (size_t i = 1; i < group_ptr_.size(); ++i) {
group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i];
}
return;
} else {
LOG(FATAL) << "Unknown metainfo: " << key;
}
dst->Reshard(GPUDistribution(GPUSet::Range(ptr_device, 1)));
dst->Resize(length);
auto p_dst = thrust::device_pointer_cast(dst->DevicePointer(0));
thrust::copy(p_src, p_src + length, p_dst);
}
} // namespace xgboost

View File

@ -1,11 +1,14 @@
/*! /*!
* Copyright 2015 by Contributors * Copyright 2015-2019 by Contributors
* \file simple_csr_source.cc * \file simple_csr_source.cc
*/ */
#include <dmlc/base.h> #include <dmlc/base.h>
#include <xgboost/logging.h> #include <xgboost/logging.h>
#include <xgboost/json.h>
#include <limits> #include <limits>
#include "./simple_csr_source.h" #include "simple_csr_source.h"
#include "columnar.h"
namespace xgboost { namespace xgboost {
namespace data { namespace data {
@ -117,5 +120,143 @@ const SparsePage& SimpleCSRSource::Value() const {
return page_; return page_;
} }
/*!
* Please be careful that, in official specification, the only three required fields are
* `shape', `version' and `typestr'. Any other is optional, including `data'. But here
* we have two additional requirements for input data:
*
* - `data' field is required, passing in an empty dataset is not accepted, as most (if
* not all) of our algorithms don't have test for empty dataset. An error is better
* than a crash.
*
* - `null_count' is required when `mask' is presented. We can compute `null_count'
* ourselves and copy the result back to host for memory allocation. But it's in the
* specification of Apache Arrow hence it should be readily available,
*
* Sample input:
* [
* {
* "shape": [
* 10
* ],
* "strides": [
* 4
* ],
* "data": [
* 30074864128,
* false
* ],
* "typestr": "<f4",
* "version": 1,
* "mask": {
* "shape": [
* 64
* ],
* "strides": [
* 1
* ],
* "data": [
* 30074864640,
* false
* ],
* "typestr": "|i1",
* "version": 1,
* "null_count": 1
* }
* }
* ]
*/
void SimpleCSRSource::CopyFrom(std::string const& cuda_interfaces_str) {
Json interfaces = Json::Load({cuda_interfaces_str.c_str(),
cuda_interfaces_str.size()});
std::vector<Json> const& columns = get<Array>(interfaces);
size_t n_columns = columns.size();
CHECK_GT(n_columns, 0);
std::vector<Columnar> foreign_cols(n_columns);
for (size_t i = 0; i < columns.size(); ++i) {
CHECK(IsA<Object>(columns[i]));
auto const& column = get<Object const>(columns[i]);
auto version = get<Integer const>(column.at("version"));
CHECK_EQ(version, 1) << ColumnarErrors::Version();
// Find null mask (validity mask) field
// Mask object is also an array interface, but with different requirements.
// TODO(trivialfis): Abstract this into a class that accept a json
// object and turn it into an array (for cupy and numba).
common::Span<RBitField8::value_type> s_mask;
int32_t null_count {0};
if (column.find("mask") != column.cend()) {
auto const& j_mask = get<Object const>(column.at("mask"));
auto p_mask = GetPtrFromArrayData<RBitField8::value_type*>(j_mask);
auto j_shape = get<Array const>(j_mask.at("shape"));
CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);
CHECK_EQ(get<Integer>(j_shape.front()) % 8, 0) <<
"Length of validity map must be a multiple of 8 bytes.";
int64_t size = get<Integer>(j_shape.at(0)) *
sizeof(unsigned char) / sizeof(RBitField8::value_type);
s_mask = {p_mask, size};
auto typestr = get<String const>(j_mask.at("typestr"));
CHECK_EQ(typestr.size(), 3) << ColumnarErrors::TypestrFormat();
CHECK_NE(typestr.front(), '>') << ColumnarErrors::BigEndian();
CHECK_EQ(typestr.at(1), 'i') << "mask" << ColumnarErrors::ofType("unsigned char");
CHECK_EQ(typestr.at(2), '1') << "mask" << ColumnarErrors::toUInt();
CHECK(j_mask.find("null_count") != j_mask.cend()) <<
"Column with null mask must include null_count as "
"part of mask object for XGBoost.";
null_count = get<Integer const>(j_mask.at("null_count"));
}
// Find data field
if (column.find("data") == column.cend()) {
LOG(FATAL) << "Empty dataset passed in.";
}
auto typestr = get<String const>(column.at("typestr"));
CHECK_EQ(typestr.size(), 3) << ColumnarErrors::TypestrFormat();
CHECK_NE(typestr.front(), '>') << ColumnarErrors::BigEndian();
CHECK_EQ(typestr.at(1), 'f') << "data" << ColumnarErrors::ofType("floating point");
CHECK_EQ(typestr.at(2), '4') << ColumnarErrors::toFloat();
auto j_shape = get<Array const>(column.at("shape"));
CHECK_EQ(j_shape.size(), 1) << ColumnarErrors::Dimension(1);
if (column.find("strides") != column.cend()) {
auto strides = get<Array const>(column.at("strides"));
CHECK_EQ(strides.size(), 1) << ColumnarErrors::Dimension(1);
CHECK_EQ(get<Integer>(strides.at(0)), 4) << ColumnarErrors::Contigious();
}
auto length = get<Integer const>(j_shape.at(0));
float* p_data = GetPtrFromArrayData<float*>(column);
common::Span<float> s_data {p_data, length};
foreign_cols[i].data = s_data;
foreign_cols[i].valid = RBitField8(s_mask);
foreign_cols[i].size = s_data.size();
foreign_cols[i].null_count = null_count;
}
info.num_col_ = n_columns;
info.num_row_ = foreign_cols[0].size;
for (size_t i = 0; i < n_columns; ++i) {
CHECK_EQ(foreign_cols[0].size, foreign_cols[i].size);
info.num_nonzero_ += foreign_cols[i].data.size() - foreign_cols[i].null_count;
}
this->FromDeviceColumnar(foreign_cols);
}
#if !defined(XGBOOST_USE_CUDA)
void SimpleCSRSource::FromDeviceColumnar(std::vector<Columnar> cols) {
LOG(FATAL) << "XGBoost version is not compiled with GPU support";
}
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace data } // namespace data
} // namespace xgboost } // namespace xgboost

View File

@ -0,0 +1,117 @@
/*!
* Copyright 2019 by XGBoost Contributors
*
* \file simple_csr_source.cuh
* \brief An extension for the simple CSR source in-memory data structure to accept
* foreign columnar.
*/
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/scan.h>
#include <xgboost/base.h>
#include <xgboost/data.h>
#include <vector>
#include <algorithm>
#include "simple_csr_source.h"
#include "columnar.h"
#include "../common/bitfield.h"
#include "../common/device_helpers.cuh"
namespace xgboost {
namespace data {
template <size_t kBlockThreads>
__global__ void CountValidKernel(common::Span<Columnar const> columns,
int32_t const n_rows,
common::Span<size_t> offsets) {
// One block for a column
auto const bid = blockIdx.x;
auto const tid = threadIdx.x;
if (bid >= columns.size()) {
return;
}
RBitField8 const mask = columns[bid].valid;
for (auto r = tid; r < n_rows; r += kBlockThreads) {
if (mask.Data() == nullptr || mask.Check(r)) {
atomicAdd(reinterpret_cast<BitFieldAtomicType*>(&offsets[r+1]),
static_cast<BitFieldAtomicType>(1));
}
}
}
__global__ void CreateCSRKernel(Columnar const column,
int32_t colid,
common::Span<size_t> offsets,
common::Span<Entry> out_data) {
auto tid = threadIdx.x + blockDim.x * blockIdx.x;
if (column.size <= tid) {
return;
}
if (column.valid.Data() == nullptr || column.valid.Check(tid)) {
int32_t oid = offsets[tid];
out_data[oid].fvalue = column.data[tid];
out_data[oid].index = colid;
offsets[tid] += 1;
}
}
void SimpleCSRSource::FromDeviceColumnar(std::vector<Columnar> cols) {
uint64_t const n_cols = cols.size();
uint64_t const n_rows = cols[0].size;
auto ptr = cols[0].data.data();
int32_t device = dh::CudaGetPointerDevice(ptr);
CHECK_NE(device, -1);
for (int32_t i = 1; i < n_cols; ++i) {
auto ptr = cols[i].data.data();
int32_t ptr_device = dh::CudaGetPointerDevice(ptr);
CHECK_EQ(device, ptr_device)
<< "GPU ID at 0^th column: " << device << ", "
<< "GPU ID at column " << i << ": " << ptr_device;
}
dh::safe_cuda(cudaSetDevice(device));
GPUSet devices = GPUSet::Range(device, 1);
page_.offset.Reshard(GPUDistribution(devices));
page_.offset.Resize(info.num_row_ + 1);
page_.data.Reshard(GPUDistribution(devices));
page_.data.Resize(info.num_nonzero_);
auto s_data = page_.data.DeviceSpan(device);
auto s_offsets = page_.offset.DeviceSpan(device);
CHECK_EQ(s_offsets.size(), n_rows + 1);
int32_t constexpr kThreads = 256;
dh::device_vector<Columnar> d_cols(cols);
auto s_d_cols = dh::ToSpan(d_cols);
dh::safe_cuda(cudaMemset(s_offsets.data(), 0, sizeof(int32_t) * (n_rows + 1)));
CountValidKernel<kThreads><<<n_cols, kThreads>>>(s_d_cols, n_rows, s_offsets);
thrust::device_ptr<size_t> p_offsets(s_offsets.data());
CHECK_GE(s_offsets.size(), n_rows + 1);
thrust::inclusive_scan(p_offsets, p_offsets + n_rows + 1, p_offsets);
// Created for building csr matrix, where we need to change index
// after processing each column.
dh::device_vector<size_t> tmp_offset(page_.offset.Size());
thrust::copy(p_offsets, p_offsets + n_rows + 1, tmp_offset.begin());
int32_t kBlocks = common::DivRoundUp(n_rows, kThreads);
for (size_t col = 0; col < n_cols; ++col) {
CreateCSRKernel<<<kBlocks, kThreads>>>(d_cols[col], col, dh::ToSpan(tmp_offset), s_data);
}
}
} // namespace data
} // namespace xgboost

View File

@ -10,9 +10,12 @@
#include <xgboost/base.h> #include <xgboost/base.h>
#include <xgboost/data.h> #include <xgboost/data.h>
#include <vector>
#include <algorithm>
#include <algorithm>
#include <string>
#include <vector>
#include "columnar.h"
namespace xgboost { namespace xgboost {
namespace data { namespace data {
@ -27,7 +30,6 @@ namespace data {
*/ */
class SimpleCSRSource : public DataSource<SparsePage> { class SimpleCSRSource : public DataSource<SparsePage> {
public: public:
// public data members
// MetaInfo info; // inheritated from DataSource // MetaInfo info; // inheritated from DataSource
SparsePage page_; SparsePage page_;
/*! \brief default constructor */ /*! \brief default constructor */
@ -47,6 +49,11 @@ class SimpleCSRSource : public DataSource<SparsePage> {
* \param info The additional information reflected in the parser. * \param info The additional information reflected in the parser.
*/ */
void CopyFrom(dmlc::Parser<uint32_t>* src); void CopyFrom(dmlc::Parser<uint32_t>* src);
/*!
* \brief copy content of data from foreign **GPU** columnar buffer.
* \param interfaces_str JSON representation of cuda array interfaces.
*/
void CopyFrom(std::string const& cuda_interfaces_str);
/*! /*!
* \brief Load data from binary stream. * \brief Load data from binary stream.
* \param fi the pointer to load data from. * \param fi the pointer to load data from.
@ -67,6 +74,11 @@ class SimpleCSRSource : public DataSource<SparsePage> {
static const int kMagic = 0xffffab01; static const int kMagic = 0xffffab01;
private: private:
/*!
* \brief copy content of data from foreign GPU columnar buffer.
* \param cols foreign columns data buffer.
*/
void FromDeviceColumnar(std::vector<Columnar> cols);
/*! \brief internal variable, used to support iterator interface */ /*! \brief internal variable, used to support iterator interface */
bool at_first_{true}; bool at_first_{true};
}; };

View File

@ -22,9 +22,6 @@
namespace xgboost { namespace xgboost {
BitField::value_type constexpr BitField::kValueSize;
BitField::value_type constexpr BitField::kOne;
size_t FeatureInteractionConstraint::Features() const { size_t FeatureInteractionConstraint::Features() const {
return d_sets_ptr_.size() - 1; return d_sets_ptr_.size() - 1;
} }
@ -51,7 +48,7 @@ void FeatureInteractionConstraint::Configure(
} }
n_sets_ = h_feature_constraints.size(); n_sets_ = h_feature_constraints.size();
size_t const n_feat_storage = BitField::ComputeStorageSize(n_features); size_t const n_feat_storage = LBitField64::ComputeStorageSize(n_features);
if (n_feat_storage == 0 && n_features != 0) { if (n_feat_storage == 0 && n_features != 0) {
LOG(FATAL) << "Wrong storage size, n_features: " << n_features; LOG(FATAL) << "Wrong storage size, n_features: " << n_features;
} }
@ -70,13 +67,13 @@ void FeatureInteractionConstraint::Configure(
node_constraints_.resize(n_nodes); node_constraints_.resize(n_nodes);
node_constraints_storage_.resize(n_nodes); node_constraints_storage_.resize(n_nodes);
for (auto& n : node_constraints_storage_) { for (auto& n : node_constraints_storage_) {
n.resize(BitField::ComputeStorageSize(n_features)); n.resize(LBitField64::ComputeStorageSize(n_features));
} }
for (size_t i = 0; i < node_constraints_storage_.size(); ++i) { for (size_t i = 0; i < node_constraints_storage_.size(); ++i) {
auto span = dh::ToSpan(node_constraints_storage_[i]); auto span = dh::ToSpan(node_constraints_storage_[i]);
node_constraints_[i] = BitField(span); node_constraints_[i] = LBitField64(span);
} }
s_node_constraints_ = common::Span<BitField>(node_constraints_.data(), s_node_constraints_ = common::Span<LBitField64>(node_constraints_.data(),
node_constraints_.size()); node_constraints_.size());
// Represent constraints as CSR format, flatten is the value vector, // Represent constraints as CSR format, flatten is the value vector,
@ -131,14 +128,14 @@ void FeatureInteractionConstraint::Configure(
s_sets_ = dh::ToSpan(d_sets_); s_sets_ = dh::ToSpan(d_sets_);
s_sets_ptr_ = dh::ToSpan(d_sets_ptr_); s_sets_ptr_ = dh::ToSpan(d_sets_ptr_);
d_feature_buffer_storage_.resize(BitField::ComputeStorageSize(n_features)); d_feature_buffer_storage_.resize(LBitField64::ComputeStorageSize(n_features));
feature_buffer_ = dh::ToSpan(d_feature_buffer_storage_); feature_buffer_ = dh::ToSpan(d_feature_buffer_storage_);
// --- Initialize result buffers. // --- Initialize result buffers.
output_buffer_bits_storage_.resize(BitField::ComputeStorageSize(n_features)); output_buffer_bits_storage_.resize(LBitField64::ComputeStorageSize(n_features));
output_buffer_bits_ = BitField(dh::ToSpan(output_buffer_bits_storage_)); output_buffer_bits_ = LBitField64(dh::ToSpan(output_buffer_bits_storage_));
input_buffer_bits_storage_.resize(BitField::ComputeStorageSize(n_features)); input_buffer_bits_storage_.resize(LBitField64::ComputeStorageSize(n_features));
input_buffer_bits_ = BitField(dh::ToSpan(input_buffer_bits_storage_)); input_buffer_bits_ = LBitField64(dh::ToSpan(input_buffer_bits_storage_));
result_buffer_.resize(n_features); result_buffer_.resize(n_features);
s_result_buffer_ = dh::ToSpan(result_buffer_); s_result_buffer_ = dh::ToSpan(result_buffer_);
} }
@ -156,7 +153,7 @@ void FeatureInteractionConstraint::Reset() {
} }
__global__ void ClearBuffersKernel( __global__ void ClearBuffersKernel(
BitField result_buffer_output, BitField result_buffer_input) { LBitField64 result_buffer_output, LBitField64 result_buffer_input) {
auto tid = blockIdx.x * blockDim.x + threadIdx.x; auto tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < result_buffer_output.Size()) { if (tid < result_buffer_output.Size()) {
result_buffer_output.Clear(tid); result_buffer_output.Clear(tid);
@ -185,7 +182,7 @@ common::Span<int32_t> FeatureInteractionConstraint::QueryNode(int32_t node_id) {
thrust::counting_iterator<int32_t> begin(0); thrust::counting_iterator<int32_t> begin(0);
thrust::counting_iterator<int32_t> end(result_buffer_.size()); thrust::counting_iterator<int32_t> end(result_buffer_.size());
auto p_result_buffer = result_buffer_.data(); auto p_result_buffer = result_buffer_.data();
BitField node_constraints = s_node_constraints_[node_id]; LBitField64 node_constraints = s_node_constraints_[node_id];
thrust::device_ptr<int32_t> const out_end = thrust::copy_if( thrust::device_ptr<int32_t> const out_end = thrust::copy_if(
thrust::device, thrust::device,
@ -201,16 +198,16 @@ common::Span<int32_t> FeatureInteractionConstraint::QueryNode(int32_t node_id) {
} }
__global__ void SetInputBufferKernel(common::Span<int32_t> feature_list_input, __global__ void SetInputBufferKernel(common::Span<int32_t> feature_list_input,
BitField result_buffer_input) { LBitField64 result_buffer_input) {
uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x; uint32_t tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < feature_list_input.size()) { if (tid < feature_list_input.size()) {
result_buffer_input.Set(feature_list_input[tid]); result_buffer_input.Set(feature_list_input[tid]);
} }
} }
__global__ void QueryFeatureListKernel(BitField node_constraints, __global__ void QueryFeatureListKernel(LBitField64 node_constraints,
BitField result_buffer_input, LBitField64 result_buffer_input,
BitField result_buffer_output) { LBitField64 result_buffer_output) {
result_buffer_output |= node_constraints; result_buffer_output |= node_constraints;
result_buffer_output &= result_buffer_input; result_buffer_output &= result_buffer_input;
} }
@ -223,7 +220,7 @@ common::Span<int32_t> FeatureInteractionConstraint::Query(
ClearBuffers(); ClearBuffers();
BitField node_constraints = s_node_constraints_[nid]; LBitField64 node_constraints = s_node_constraints_[nid];
CHECK_EQ(input_buffer_bits_.Size(), output_buffer_bits_.Size()); CHECK_EQ(input_buffer_bits_.Size(), output_buffer_bits_.Size());
int constexpr kBlockThreads = 256; int constexpr kBlockThreads = 256;
@ -237,7 +234,7 @@ common::Span<int32_t> FeatureInteractionConstraint::Query(
thrust::counting_iterator<int32_t> begin(0); thrust::counting_iterator<int32_t> begin(0);
thrust::counting_iterator<int32_t> end(result_buffer_.size()); thrust::counting_iterator<int32_t> end(result_buffer_.size());
BitField local_result_buffer = output_buffer_bits_; LBitField64 local_result_buffer = output_buffer_bits_;
thrust::device_ptr<int32_t> const out_end = thrust::copy_if( thrust::device_ptr<int32_t> const out_end = thrust::copy_if(
thrust::device, thrust::device,
@ -257,7 +254,7 @@ common::Span<int32_t> FeatureInteractionConstraint::Query(
// Find interaction sets for each feature, then store all features in // Find interaction sets for each feature, then store all features in
// those sets in a buffer. // those sets in a buffer.
__global__ void RestoreFeatureListFromSetsKernel( __global__ void RestoreFeatureListFromSetsKernel(
BitField feature_buffer, LBitField64 feature_buffer,
int32_t fid, int32_t fid,
common::Span<int32_t> feature_interactions, common::Span<int32_t> feature_interactions,
@ -283,11 +280,11 @@ __global__ void RestoreFeatureListFromSetsKernel(
} }
} }
__global__ void InteractionConstraintSplitKernel(BitField feature, __global__ void InteractionConstraintSplitKernel(LBitField64 feature,
int32_t feature_id, int32_t feature_id,
BitField node, LBitField64 node,
BitField left, LBitField64 left,
BitField right) { LBitField64 right) {
auto tid = threadIdx.x + blockDim.x * blockIdx.x; auto tid = threadIdx.x + blockDim.x * blockIdx.x;
if (tid > node.Size()) { if (tid > node.Size()) {
return; return;
@ -324,9 +321,9 @@ void FeatureInteractionConstraint::Split(
CHECK_LT(right_id, s_node_constraints_.size()); CHECK_LT(right_id, s_node_constraints_.size());
CHECK_NE(s_node_constraints_.size(), 0); CHECK_NE(s_node_constraints_.size(), 0);
BitField node = s_node_constraints_[node_id]; LBitField64 node = s_node_constraints_[node_id];
BitField left = s_node_constraints_[left_id]; LBitField64 left = s_node_constraints_[left_id];
BitField right = s_node_constraints_[right_id]; LBitField64 right = s_node_constraints_[right_id];
dim3 const block3(16, 64, 1); dim3 const block3(16, 64, 1);
dim3 const grid3(common::DivRoundUp(n_sets_, 16), dim3 const grid3(common::DivRoundUp(n_sets_, 16),

View File

@ -13,7 +13,7 @@
#include "param.h" #include "param.h"
#include "../common/span.h" #include "../common/span.h"
#include "../common/bitfield.cuh" #include "../common/bitfield.h"
#include "../common/device_helpers.cuh" #include "../common/device_helpers.cuh"
namespace xgboost { namespace xgboost {
@ -115,25 +115,25 @@ struct FeatureInteractionConstraint {
// Allowed features attached to each node, have n_nodes bitfields, // Allowed features attached to each node, have n_nodes bitfields,
// each of size n_features. // each of size n_features.
std::vector<dh::device_vector<BitField::value_type>> node_constraints_storage_; std::vector<dh::device_vector<LBitField64::value_type>> node_constraints_storage_;
std::vector<BitField> node_constraints_; std::vector<LBitField64> node_constraints_;
common::Span<BitField> s_node_constraints_; common::Span<LBitField64> s_node_constraints_;
// buffer storing return feature list from Query, of size n_features. // buffer storing return feature list from Query, of size n_features.
dh::device_vector<int32_t> result_buffer_; dh::device_vector<int32_t> result_buffer_;
common::Span<int32_t> s_result_buffer_; common::Span<int32_t> s_result_buffer_;
// Temp buffers, one bit for each possible feature. // Temp buffers, one bit for each possible feature.
dh::device_vector<BitField::value_type> output_buffer_bits_storage_; dh::device_vector<LBitField64::value_type> output_buffer_bits_storage_;
BitField output_buffer_bits_; LBitField64 output_buffer_bits_;
dh::device_vector<BitField::value_type> input_buffer_bits_storage_; dh::device_vector<LBitField64::value_type> input_buffer_bits_storage_;
BitField input_buffer_bits_; LBitField64 input_buffer_bits_;
/* /*
* Combined features from all interaction sets that one feature belongs to. * Combined features from all interaction sets that one feature belongs to.
* For an input with [[0, 1], [1, 2]], the feature 1 belongs to sets {0, 1} * For an input with [[0, 1], [1, 2]], the feature 1 belongs to sets {0, 1}
*/ */
dh::device_vector<BitField::value_type> d_feature_buffer_storage_; dh::device_vector<LBitField64::value_type> d_feature_buffer_storage_;
BitField feature_buffer_; // of Size n features. LBitField64 feature_buffer_; // of Size n features.
// Clear out all temp buffers except for `feature_buffer_', which is // Clear out all temp buffers except for `feature_buffer_', which is
// handled in `Split'. // handled in `Split'.

View File

@ -0,0 +1,92 @@
/*!
* Copyright 2019 XGBoost contributors
*/
#include <gtest/gtest.h>
#include "../../../src/common/bitfield.h"
namespace xgboost {
TEST(BitField, Check) {
{
std::vector<LBitField64::value_type> storage(4, 0);
storage[2] = 2;
auto bits = LBitField64({storage.data(),
static_cast<typename common::Span<LBitField64::value_type>::index_type>(
storage.size())});
size_t true_bit = 190;
for (size_t i = true_bit + 1; i < bits.Size(); ++i) {
ASSERT_FALSE(bits.Check(i));
}
ASSERT_TRUE(bits.Check(true_bit));
for (size_t i = 0; i < true_bit; ++i) {
ASSERT_FALSE(bits.Check(i));
}
}
{
std::vector<RBitField8::value_type> storage(4, 0);
storage[2] = 1 << 3;
auto bits = RBitField8({storage.data(),
static_cast<typename common::Span<RBitField8::value_type>::index_type>(
storage.size())});
size_t true_bit = 19;
for (size_t i = 0; i < true_bit; ++i) {
ASSERT_FALSE(bits.Check(i));
}
ASSERT_TRUE(bits.Check(true_bit));
for (size_t i = true_bit + 1; i < bits.Size(); ++i) {
ASSERT_FALSE(bits.Check(i));
}
}
}
template <typename BitFieldT, typename VT = typename BitFieldT::value_type>
void TestBitFieldSet(typename BitFieldT::value_type res, size_t index, size_t true_bit) {
using IndexT = typename common::Span<VT>::index_type;
std::vector<VT> storage(4, 0);
auto bits = BitFieldT({storage.data(), static_cast<IndexT>(storage.size())});
bits.Set(true_bit);
for (size_t i = 0; i < true_bit; ++i) {
ASSERT_FALSE(bits.Check(i));
}
ASSERT_TRUE(bits.Check(true_bit));
for (size_t i = true_bit + 1; i < storage.size() * BitFieldT::kValueSize; ++i) {
ASSERT_FALSE(bits.Check(i));
}
ASSERT_EQ(storage[index], res);
}
TEST(BitField, Set) {
{
TestBitFieldSet<LBitField64>(2, 2, 190);
}
{
TestBitFieldSet<RBitField8>(1 << 3, 2, 19);
}
}
template <typename BitFieldT, typename VT = typename BitFieldT::value_type>
void TestBitFieldClear(size_t clear_bit) {
using IndexT = typename common::Span<VT>::index_type;
std::vector<VT> storage(4, 0);
auto bits = BitFieldT({storage.data(), static_cast<IndexT>(storage.size())});
bits.Set(clear_bit);
bits.Clear(clear_bit);
ASSERT_FALSE(bits.Check(clear_bit));
}
TEST(BitField, Clear) {
{
TestBitFieldClear<LBitField64>(190);
}
{
TestBitFieldClear<RBitField8>(19);
}
}
} // namespace xgboost

View File

@ -5,56 +5,55 @@
#include <thrust/copy.h> #include <thrust/copy.h>
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <vector> #include <vector>
#include "../../../src/common/bitfield.cuh" #include "../../../src/common/bitfield.h"
#include "../../../src/common/device_helpers.cuh" #include "../../../src/common/device_helpers.cuh"
namespace xgboost { namespace xgboost {
__global__ void TestSetKernel(BitField bits) { __global__ void TestSetKernel(LBitField64 bits) {
auto tid = threadIdx.x + blockIdx.x * blockDim.x; auto tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid < bits.Size()) { if (tid < bits.Size()) {
bits.Set(tid); bits.Set(tid);
} }
} }
TEST(BitField, Set) { TEST(BitField, GPU_Set) {
dh::device_vector<BitField::value_type> storage; dh::device_vector<LBitField64::value_type> storage;
uint32_t constexpr kBits = 128; uint32_t constexpr kBits = 128;
storage.resize(128); storage.resize(128);
auto bits = BitField(dh::ToSpan(storage)); auto bits = LBitField64(dh::ToSpan(storage));
TestSetKernel<<<1, kBits>>>(bits); TestSetKernel<<<1, kBits>>>(bits);
std::vector<BitField::value_type> h_storage(storage.size()); std::vector<LBitField64::value_type> h_storage(storage.size());
thrust::copy(storage.begin(), storage.end(), h_storage.begin()); thrust::copy(storage.begin(), storage.end(), h_storage.begin());
BitField outputs { LBitField64 outputs {
common::Span<BitField::value_type>{h_storage.data(), common::Span<LBitField64::value_type>{h_storage.data(),
h_storage.data() + h_storage.size()}}; h_storage.data() + h_storage.size()}};
for (size_t i = 0; i < kBits; ++i) { for (size_t i = 0; i < kBits; ++i) {
ASSERT_TRUE(outputs.Check(i)); ASSERT_TRUE(outputs.Check(i));
} }
} }
__global__ void TestOrKernel(BitField lhs, BitField rhs) { __global__ void TestOrKernel(LBitField64 lhs, LBitField64 rhs) {
lhs |= rhs; lhs |= rhs;
} }
TEST(BitField, And) { TEST(BitField, GPU_And) {
uint32_t constexpr kBits = 128; uint32_t constexpr kBits = 128;
dh::device_vector<BitField::value_type> lhs_storage(kBits); dh::device_vector<LBitField64::value_type> lhs_storage(kBits);
dh::device_vector<BitField::value_type> rhs_storage(kBits); dh::device_vector<LBitField64::value_type> rhs_storage(kBits);
auto lhs = BitField(dh::ToSpan(lhs_storage)); auto lhs = LBitField64(dh::ToSpan(lhs_storage));
auto rhs = BitField(dh::ToSpan(rhs_storage)); auto rhs = LBitField64(dh::ToSpan(rhs_storage));
thrust::fill(lhs_storage.begin(), lhs_storage.end(), 0UL); thrust::fill(lhs_storage.begin(), lhs_storage.end(), 0UL);
thrust::fill(rhs_storage.begin(), rhs_storage.end(), ~static_cast<BitField::value_type>(0UL)); thrust::fill(rhs_storage.begin(), rhs_storage.end(), ~static_cast<LBitField64::value_type>(0UL));
TestOrKernel<<<1, kBits>>>(lhs, rhs); TestOrKernel<<<1, kBits>>>(lhs, rhs);
std::vector<BitField::value_type> h_storage(lhs_storage.size()); std::vector<LBitField64::value_type> h_storage(lhs_storage.size());
thrust::copy(lhs_storage.begin(), lhs_storage.end(), h_storage.begin()); thrust::copy(lhs_storage.begin(), lhs_storage.end(), h_storage.begin());
BitField outputs {{h_storage.data(), h_storage.data() + h_storage.size()}}; LBitField64 outputs {{h_storage.data(), h_storage.data() + h_storage.size()}};
for (size_t i = 0; i < kBits; ++i) { for (size_t i = 0; i < kBits; ++i) {
ASSERT_TRUE(outputs.Check(i)); ASSERT_TRUE(outputs.Check(i));
} }
} }
} // namespace xgboost } // namespace xgboost

View File

@ -0,0 +1,57 @@
/*! Copyright 2019 by Contributors */
#include <gtest/gtest.h>
#include <xgboost/data.h>
#include <xgboost/json.h>
#include <thrust/device_vector.h>
#include "../../../src/common/device_helpers.cuh"
namespace xgboost {
TEST(MetaInfo, FromInterface) {
cudaSetDevice(0);
constexpr size_t kRows = 16;
thrust::device_vector<float> d_data(kRows);
for (size_t i = 0; i < d_data.size(); ++i) {
d_data[i] = i * 2.0;
}
Json column { Object() };
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
column["shape"] = Array(j_shape);
column["strides"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(4)))});
column["version"] = Integer(static_cast<Integer::Int>(1));
column["typestr"] = String("<f4");
auto p_d_data = dh::Raw(d_data);
std::vector<Json> j_data {
Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
Json(Boolean(false))};
column["data"] = j_data;
std::stringstream ss;
Json::Dump(column, &ss);
std::string str = ss.str();
MetaInfo info;
info.SetInfo("label", str.c_str());
auto const& h_label = info.labels_.HostVector();
for (size_t i = 0; i < d_data.size(); ++i) {
ASSERT_EQ(h_label[i], d_data[i]);
}
info.SetInfo("weight", str.c_str());
auto const& h_weight = info.weights_.HostVector();
for (size_t i = 0; i < d_data.size(); ++i) {
ASSERT_EQ(h_weight[i], d_data[i]);
}
info.SetInfo("base_margin", str.c_str());
auto const& h_base_margin = info.base_margin_.HostVector();
for (size_t i = 0; i < d_data.size(); ++i) {
ASSERT_EQ(h_base_margin[i], d_data[i]);
}
}
} // namespace xgboost

View File

@ -1,10 +1,15 @@
// Copyright by Contributors // Copyright by Contributors
#include <xgboost/data.h> #include <gtest/gtest.h>
#include <dmlc/filesystem.h> #include <dmlc/filesystem.h>
#include <xgboost/data.h>
#include <xgboost/json.h>
#include "../../../src/data/simple_csr_source.h" #include "../../../src/data/simple_csr_source.h"
#include "../helpers.h" #include "../helpers.h"
namespace xgboost {
TEST(SimpleCSRSource, SaveLoadBinary) { TEST(SimpleCSRSource, SaveLoadBinary) {
dmlc::TemporaryDirectory tempdir; dmlc::TemporaryDirectory tempdir;
const std::string tmp_file = tempdir.path + "/simple.libsvm"; const std::string tmp_file = tempdir.path + "/simple.libsvm";
@ -33,3 +38,4 @@ TEST(SimpleCSRSource, SaveLoadBinary) {
delete dmat; delete dmat;
delete dmat_read; delete dmat_read;
} }
} // namespace xgboost

View File

@ -0,0 +1,216 @@
// Copyright (c) 2019 by Contributors
#include <gtest/gtest.h>
#include <xgboost/data.h>
#include <xgboost/json.h>
#include <thrust/device_vector.h>
#include <memory>
#include "../../../src/common/bitfield.h"
#include "../../../src/common/device_helpers.cuh"
#include "../../../src/data/simple_csr_source.h"
namespace xgboost {
TEST(SimpleCSRSource, FromColumnarDense) {
constexpr size_t kRows = 16;
Json column { Object() };
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
column["shape"] = Array(j_shape);
column["strides"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(4)))});
thrust::device_vector<float> d_data(kRows);
for (size_t i = 0; i < d_data.size(); ++i) {
d_data[i] = i * 2.0;
}
auto p_d_data = dh::Raw(d_data);
std::vector<Json> j_data {
Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
Json(Boolean(false))};
column["data"] = j_data;
column["version"] = Integer(static_cast<Integer::Int>(1));
column["typestr"] = String("<f4");
Json column_arr {Array{std::vector<Json>{column}}};
std::stringstream ss;
Json::Dump(column_arr, &ss);
std::string str = ss.str();
std::unique_ptr<data::SimpleCSRSource> source (new data::SimpleCSRSource());
source->CopyFrom(str.c_str());
auto const& data = source->page_.data.HostVector();
auto const& offset = source->page_.offset.HostVector();
for (size_t i = 0; i < kRows; ++i) {
auto e = data[i];
ASSERT_NEAR(e.fvalue, i * 2.0, kRtEps);
ASSERT_EQ(e.index, 0); // feature 0
}
ASSERT_EQ(offset.back(), 16);
for (size_t i = 0; i < kRows + 1; ++i) {
ASSERT_EQ(offset[i], i);
}
}
TEST(SimpleCSRSource, FromColumnarWithEmptyRows) {
// In this test we construct a data storage similar to cudf
constexpr size_t kRows = 102;
constexpr size_t kCols = 24;
constexpr size_t kMissingRows = 3;
std::vector<Json> v_columns (kCols);
std::vector<dh::device_vector<float>> columns_data(kCols);
std::vector<dh::device_vector<unsigned char>> column_bitfields(kCols);
unsigned char constexpr kUCOne = 1;
for (size_t i = 0; i < kCols; ++i) {
auto& col = v_columns[i];
col = Object();
auto& data = columns_data[i];
data.resize(kRows);
thrust::sequence(data.begin(), data.end(), 0);
dh::safe_cuda(cudaDeviceSynchronize());
dh::safe_cuda(cudaGetLastError());
ASSERT_EQ(data.size(), kRows);
auto p_d_data = raw_pointer_cast(data.data());
std::vector<Json> j_data {
Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
Json(Boolean(false))};
col["data"] = j_data;
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
col["shape"] = Array(j_shape);
col["version"] = Integer(static_cast<Integer::Int>(1));
col["typestr"] = String("<f4");
// Construct the mask object.
col["mask"] = Object();
auto& j_mask = col["mask"];
auto& mask_storage = column_bitfields[i];
mask_storage.resize(16); // 16 bytes
mask_storage[0] = ~(kUCOne << 2); // 3^th row is missing
mask_storage[1] = ~(kUCOne << 3); // 12^th row is missing
size_t last_ind = 12;
mask_storage[last_ind] = ~(kUCOne << 5);
std::set<size_t> missing_row_index {0, 1, last_ind};
for (size_t i = 0; i < mask_storage.size(); ++i) {
if (missing_row_index.find(i) == missing_row_index.cend()) {
// all other rows are valid
mask_storage[i] = ~0;
}
}
j_mask["data"] = std::vector<Json>{
Json(Integer(reinterpret_cast<Integer::Int>(mask_storage.data().get()))),
Json(Boolean(false))};
j_mask["shape"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(16)))});
j_mask["typestr"] = String("|i1");
j_mask["null_count"] = Json(Integer(static_cast<Integer::Int>(kMissingRows)));
}
Json column_arr {Array(v_columns)};
std::stringstream ss;
Json::Dump(column_arr, &ss);
std::string str = ss.str();
std::unique_ptr<data::SimpleCSRSource> source (new data::SimpleCSRSource());
source->CopyFrom(str.c_str());
auto const& data = source->page_.data.HostVector();
auto const& offset = source->page_.offset.HostVector();
ASSERT_EQ(offset.size(), kRows + 1);
for (size_t i = 1; i < offset.size(); ++i) {
for (size_t j = offset[i-1]; j < offset[i]; ++j) {
ASSERT_EQ(data[j].index, j % kCols);
ASSERT_NEAR(data[j].fvalue, i - 1, kRtEps);
}
}
}
TEST(SimpleCSRSource, FromColumnarSparse) {
constexpr size_t kRows = 32;
constexpr size_t kCols = 2;
unsigned char constexpr kUCOne = 1;
std::vector<dh::device_vector<float>> columns_data(kCols);
std::vector<dh::device_vector<unsigned char>> column_bitfields(kCols);
{
// column 0
auto& mask = column_bitfields[0];
mask.resize(8);
for (size_t j = 0; j < mask.size(); ++j) {
mask[j] = ~0;
}
mask[0] = ~(kUCOne << 2);
}
{
// column 1
auto& mask = column_bitfields[1];
mask.resize(8);
for (size_t j = 0; j < mask.size(); ++j) {
mask[j] = ~0;
}
mask[2] = ~(kUCOne << 3);
}
for (size_t c = 0; c < kCols; ++c) {
columns_data[c].resize(kRows);
thrust::sequence(columns_data[c].begin(), columns_data[c].end(), 0);
}
std::vector<Json> j_columns(kCols);
for (size_t c = 0; c < kCols; ++c) {
auto& column = j_columns[c];
column = Object();
column["version"] = Integer(static_cast<Integer::Int>(1));
column["typestr"] = String("<f4");
auto p_d_data = raw_pointer_cast(columns_data[c].data());
std::vector<Json> j_data {
Json(Integer(reinterpret_cast<Integer::Int>(p_d_data))),
Json(Boolean(false))};
column["data"] = j_data;
std::vector<Json> j_shape {Json(Integer(static_cast<Integer::Int>(kRows)))};
column["shape"] = Array(j_shape);
column["version"] = Integer(static_cast<Integer::Int>(1));
column["typestr"] = String("<f4");
column["mask"] = Object();
auto& j_mask = column["mask"];
j_mask["data"] = std::vector<Json>{
Json(Integer(reinterpret_cast<Integer::Int>(column_bitfields[c].data().get()))),
Json(Boolean(false))};
j_mask["shape"] = Array(std::vector<Json>{Json(Integer(static_cast<Integer::Int>(8)))});
j_mask["typestr"] = String("|i1");
j_mask["null_count"] = Json(Integer(static_cast<Integer::Int>(1)));
}
Json column_arr {Array(j_columns)};
std::stringstream ss;
Json::Dump(column_arr, &ss);
std::string str = ss.str();
std::unique_ptr<data::SimpleCSRSource> source (new data::SimpleCSRSource());
source->CopyFrom(str.c_str());
auto const& data = source->page_.data.HostVector();
auto const& offset = source->page_.offset.HostVector();
ASSERT_EQ(offset.size(), kRows + 1);
ASSERT_EQ(data[4].index, 1);
ASSERT_EQ(data[4].fvalue, 2);
ASSERT_EQ(data[37].index, 0);
ASSERT_EQ(data[37].fvalue, 19);
}
} // namespace xgboost

View File

@ -16,7 +16,7 @@ namespace xgboost {
namespace { namespace {
struct FConstraintWrapper : public FeatureInteractionConstraint { struct FConstraintWrapper : public FeatureInteractionConstraint {
common::Span<BitField> GetNodeConstraints() { common::Span<LBitField64> GetNodeConstraints() {
return FeatureInteractionConstraint::s_node_constraints_; return FeatureInteractionConstraint::s_node_constraints_;
} }
FConstraintWrapper(tree::TrainParam param, int32_t n_features) : FConstraintWrapper(tree::TrainParam param, int32_t n_features) :
@ -44,13 +44,13 @@ tree::TrainParam GetParameter() {
return param; return param;
} }
void CompareBitField(BitField d_field, std::set<uint32_t> positions) { void CompareBitField(LBitField64 d_field, std::set<uint32_t> positions) {
std::vector<BitField::value_type> h_field_storage(d_field.bits_.size()); std::vector<LBitField64::value_type> h_field_storage(d_field.bits_.size());
thrust::copy(thrust::device_ptr<BitField::value_type>(d_field.bits_.data()), thrust::copy(thrust::device_ptr<LBitField64::value_type>(d_field.bits_.data()),
thrust::device_ptr<BitField::value_type>( thrust::device_ptr<LBitField64::value_type>(
d_field.bits_.data() + d_field.bits_.size()), d_field.bits_.data() + d_field.bits_.size()),
h_field_storage.data()); h_field_storage.data());
BitField h_field; LBitField64 h_field;
h_field.bits_ = {h_field_storage.data(), h_field_storage.data() + h_field_storage.size()}; h_field.bits_ = {h_field_storage.data(), h_field_storage.data() + h_field_storage.size()};
for (size_t i = 0; i < h_field.Size(); ++i) { for (size_t i = 0; i < h_field.Size(); ++i) {
@ -71,14 +71,14 @@ TEST(FeatureInteractionConstraint, Init) {
tree::TrainParam param = GetParameter(); tree::TrainParam param = GetParameter();
FConstraintWrapper constraints(param, kFeatures); FConstraintWrapper constraints(param, kFeatures);
ASSERT_EQ(constraints.Features(), kFeatures); ASSERT_EQ(constraints.Features(), kFeatures);
common::Span<BitField> s_nodes_constraints = constraints.GetNodeConstraints(); common::Span<LBitField64> s_nodes_constraints = constraints.GetNodeConstraints();
for (BitField const& d_node : s_nodes_constraints) { for (LBitField64 const& d_node : s_nodes_constraints) {
std::vector<BitField::value_type> h_node_storage(d_node.bits_.size()); std::vector<LBitField64::value_type> h_node_storage(d_node.bits_.size());
thrust::copy(thrust::device_ptr<BitField::value_type>(d_node.bits_.data()), thrust::copy(thrust::device_ptr<LBitField64::value_type>(d_node.bits_.data()),
thrust::device_ptr<BitField::value_type>( thrust::device_ptr<LBitField64::value_type>(
d_node.bits_.data() + d_node.bits_.size()), d_node.bits_.data() + d_node.bits_.size()),
h_node_storage.data()); h_node_storage.data());
BitField h_node; LBitField64 h_node;
h_node.bits_ = {h_node_storage.data(), h_node_storage.data() + h_node_storage.size()}; h_node.bits_ = {h_node_storage.data(), h_node_storage.data() + h_node_storage.size()};
// no feature is attached to node. // no feature is attached to node.
for (size_t i = 0; i < h_node.Size(); ++i) { for (size_t i = 0; i < h_node.Size(); ++i) {
@ -108,7 +108,7 @@ TEST(FeatureInteractionConstraint, Init) {
} }
{ {
// Test having more than 1 BitField::value_type // Test having more than 1 LBitField64::value_type
int32_t constexpr kFeatures = 129; int32_t constexpr kFeatures = 129;
tree::TrainParam param = GetParameter(); tree::TrainParam param = GetParameter();
param.interaction_constraints = R"([[0, 1, 3], [3, 5, 128], [127, 128]])"; param.interaction_constraints = R"([[0, 1, 3], [3, 5, 128], [127, 128]])";
@ -129,7 +129,7 @@ TEST(FeatureInteractionConstraint, Split) {
FConstraintWrapper constraints(param, kFeatures); FConstraintWrapper constraints(param, kFeatures);
{ {
BitField d_node[3]; LBitField64 d_node[3];
constraints.Split(0, /*feature_id=*/1, 1, 2); constraints.Split(0, /*feature_id=*/1, 1, 2);
for (size_t nid = 0; nid < 3; ++nid) { for (size_t nid = 0; nid < 3; ++nid) {
d_node[nid] = constraints.GetNodeConstraints()[nid]; d_node[nid] = constraints.GetNodeConstraints()[nid];
@ -139,7 +139,7 @@ TEST(FeatureInteractionConstraint, Split) {
} }
{ {
BitField d_node[5]; LBitField64 d_node[5];
constraints.Split(1, /*feature_id=*/0, /*left_id=*/3, /*right_id=*/4); constraints.Split(1, /*feature_id=*/0, /*left_id=*/3, /*right_id=*/4);
for (auto nid : {1, 3, 4}) { for (auto nid : {1, 3, 4}) {
d_node[nid] = constraints.GetNodeConstraints()[nid]; d_node[nid] = constraints.GetNodeConstraints()[nid];

View File

@ -0,0 +1,44 @@
import numpy as np
import xgboost as xgb
import sys
import pytest
sys.path.append("tests/python")
import testing as tm
pytestmark = pytest.mark.skipif(**tm.no_cudf())
class TestFromColumnar:
'''Tests for constructing DMatrix from data structure conforming Apache
Arrow specification.'''
@pytest.mark.skipif(**tm.no_cudf())
def test_from_cudf():
'''Test constructing DMatrix from cudf'''
import cudf
import pandas as pd
kRows = 80
kCols = 2
na = np.random.randn(kRows, kCols).astype(np.float32)
na[3, 1] = np.NAN
na[5, 0] = np.NAN
pa = pd.DataFrame(na)
np_label = np.random.randn(kRows).astype(np.float32)
pa_label = pd.DataFrame(np_label)
names = []
for i in range(0, kCols):
names.append(str(i))
pa.columns = names
cd: cudf.DataFrame = cudf.from_pandas(pa)
cd_label: cudf.DataFrame = cudf.from_pandas(pa_label)
dtrain = xgb.DMatrix(cd, label=cd_label)
assert dtrain.num_col() == kCols
assert dtrain.num_row() == kRows

View File

@ -1,5 +1,6 @@
# coding: utf-8 # coding: utf-8
from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED, DT_INSTALLED, DASK_INSTALLED from xgboost.compat import SKLEARN_INSTALLED, PANDAS_INSTALLED, DT_INSTALLED
from xgboost.compat import CUDF_INSTALLED, DASK_INSTALLED
def no_sklearn(): def no_sklearn():
@ -31,3 +32,8 @@ def no_matplotlib():
except ImportError: except ImportError:
return {'condition': True, return {'condition': True,
'reason': reason} 'reason': reason}
def no_cudf():
return {'condition': not CUDF_INSTALLED,
'reason': 'CUDF is not installed'}