Enhance inplace prediction. (#6653)
* Accept array interface for CSR and dense arrays. * Accept an optional proxy DMatrix for meta info. This constructs an explicit `_ProxyDMatrix` type in Python. * Remove unused documentation. * Add strict output.
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2019~2020 by Contributors
|
||||
* Copyright (c) 2019~2021 by Contributors
|
||||
* \file adapter.h
|
||||
*/
|
||||
#ifndef XGBOOST_DATA_ADAPTER_H_
|
||||
@@ -228,6 +228,128 @@ class DenseAdapter : public detail::SingleBatchDataIter<DenseAdapterBatch> {
|
||||
size_t num_columns_;
|
||||
};
|
||||
|
||||
class ArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
ArrayInterface array_interface_;
|
||||
|
||||
class Line {
|
||||
ArrayInterface array_interface_;
|
||||
size_t ridx_;
|
||||
public:
|
||||
Line(ArrayInterface array_interface, size_t ridx)
|
||||
: array_interface_{std::move(array_interface)}, ridx_{ridx} {}
|
||||
|
||||
size_t Size() const { return array_interface_.num_cols; }
|
||||
|
||||
COOTuple GetElement(size_t idx) const {
|
||||
return {ridx_, idx, array_interface_.GetElement(idx)};
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
ArrayAdapterBatch() = default;
|
||||
Line const GetLine(size_t idx) const {
|
||||
auto line = array_interface_.SliceRow(idx);
|
||||
return Line{line, idx};
|
||||
}
|
||||
|
||||
explicit ArrayAdapterBatch(ArrayInterface array_interface)
|
||||
: array_interface_{std::move(array_interface)} {}
|
||||
};
|
||||
|
||||
/**
|
||||
* Adapter for dense array on host, in Python that's `numpy.ndarray`. This is similar to
|
||||
* `DenseAdapter`, but supports __array_interface__ instead of raw pointers. An
|
||||
* advantage is this can handle various data type without making a copy.
|
||||
*/
|
||||
class ArrayAdapter : public detail::SingleBatchDataIter<ArrayAdapterBatch> {
|
||||
public:
|
||||
explicit ArrayAdapter(StringView array_interface) {
|
||||
auto j = Json::Load(array_interface);
|
||||
array_interface_ = ArrayInterface(get<Object const>(j));
|
||||
batch_ = ArrayAdapterBatch{array_interface_};
|
||||
}
|
||||
ArrayAdapterBatch const& Value() const override { return batch_; }
|
||||
size_t NumRows() const { return array_interface_.num_rows; }
|
||||
size_t NumColumns() const { return array_interface_.num_cols; }
|
||||
|
||||
private:
|
||||
ArrayAdapterBatch batch_;
|
||||
ArrayInterface array_interface_;
|
||||
};
|
||||
|
||||
class CSRArrayAdapterBatch : public detail::NoMetaInfo {
|
||||
ArrayInterface indptr_;
|
||||
ArrayInterface indices_;
|
||||
ArrayInterface values_;
|
||||
|
||||
class Line {
|
||||
ArrayInterface indices_;
|
||||
ArrayInterface values_;
|
||||
size_t ridx_;
|
||||
|
||||
public:
|
||||
Line(ArrayInterface indices, ArrayInterface values, size_t ridx)
|
||||
: indices_{std::move(indices)}, values_{std::move(values)}, ridx_{ridx} {}
|
||||
|
||||
COOTuple GetElement(size_t idx) const {
|
||||
return {ridx_, indices_.GetElement<size_t>(idx), values_.GetElement(idx)};
|
||||
}
|
||||
size_t Size() const {
|
||||
return values_.num_rows * values_.num_cols;
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
CSRArrayAdapterBatch() = default;
|
||||
CSRArrayAdapterBatch(ArrayInterface indptr, ArrayInterface indices,
|
||||
ArrayInterface values)
|
||||
: indptr_{std::move(indptr)}, indices_{std::move(indices)},
|
||||
values_{std::move(values)} {}
|
||||
|
||||
Line const GetLine(size_t idx) const {
|
||||
auto begin_offset = indptr_.GetElement<size_t>(idx);
|
||||
auto end_offset = indptr_.GetElement<size_t>(idx + 1);
|
||||
auto indices = indices_.SliceOffset(begin_offset);
|
||||
auto values = values_.SliceOffset(begin_offset);
|
||||
values.num_cols = end_offset - begin_offset;
|
||||
values.num_rows = 1;
|
||||
indices.num_cols = values.num_cols;
|
||||
indices.num_rows = values.num_rows;
|
||||
return Line{indices, values, idx};
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Adapter for CSR array on host, in Python that's `scipy.sparse.csr_matrix`. This is
|
||||
* similar to `CSRAdapter`, but supports __array_interface__ instead of raw pointers. An
|
||||
* advantage is this can handle various data type without making a copy.
|
||||
*/
|
||||
class CSRArrayAdapter : public detail::SingleBatchDataIter<CSRArrayAdapterBatch> {
|
||||
public:
|
||||
CSRArrayAdapter(StringView indptr, StringView indices, StringView values,
|
||||
size_t num_cols)
|
||||
: indptr_{indptr}, indices_{indices}, values_{values}, num_cols_{num_cols} {
|
||||
batch_ = CSRArrayAdapterBatch{indptr_, indices_, values_};
|
||||
}
|
||||
|
||||
CSRArrayAdapterBatch const& Value() const override {
|
||||
return batch_;
|
||||
}
|
||||
size_t NumRows() const {
|
||||
size_t size = indptr_.num_cols * indptr_.num_rows;
|
||||
size = size == 0 ? 0 : size - 1;
|
||||
return size;
|
||||
}
|
||||
size_t NumColumns() const { return num_cols_; }
|
||||
|
||||
private:
|
||||
CSRArrayAdapterBatch batch_;
|
||||
ArrayInterface indptr_;
|
||||
ArrayInterface indices_;
|
||||
ArrayInterface values_;
|
||||
size_t num_cols_;
|
||||
};
|
||||
|
||||
class CSCAdapterBatch : public detail::NoMetaInfo {
|
||||
public:
|
||||
CSCAdapterBatch(const size_t* col_ptr, const unsigned* row_idx,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019 by Contributors
|
||||
* Copyright 2019-2021 by Contributors
|
||||
* \file array_interface.h
|
||||
* \brief View of __array_interface__
|
||||
*/
|
||||
@@ -87,7 +87,7 @@ struct ArrayInterfaceErrors {
|
||||
}
|
||||
}
|
||||
|
||||
static std::string UnSupportedType(const char (&typestr)[3]) {
|
||||
static std::string UnSupportedType(StringView typestr) {
|
||||
return TypeStr(typestr[1]) + " is not supported.";
|
||||
}
|
||||
};
|
||||
@@ -210,6 +210,7 @@ class ArrayInterfaceHandler {
|
||||
static_cast<bst_feature_t>(get<Integer const>(j_shape.at(1)))};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static common::Span<T> ExtractData(std::map<std::string, Json> const& column) {
|
||||
Validate(column);
|
||||
@@ -257,16 +258,24 @@ class ArrayInterface {
|
||||
}
|
||||
|
||||
auto typestr = get<String const>(column.at("typestr"));
|
||||
type[0] = typestr.at(0);
|
||||
type[1] = typestr.at(1);
|
||||
type[2] = typestr.at(2);
|
||||
this->CheckType();
|
||||
this->AssignType(StringView{typestr});
|
||||
}
|
||||
|
||||
public:
|
||||
enum Type : std::int8_t { kF4, kF8, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
|
||||
|
||||
public:
|
||||
ArrayInterface() = default;
|
||||
explicit ArrayInterface(std::string const& str, bool allow_mask = true) {
|
||||
auto jinterface = Json::Load({str.c_str(), str.size()});
|
||||
explicit ArrayInterface(std::string const &str, bool allow_mask = true)
|
||||
: ArrayInterface{StringView{str.c_str(), str.size()}, allow_mask} {}
|
||||
|
||||
explicit ArrayInterface(std::map<std::string, Json> const &column,
|
||||
bool allow_mask = true) {
|
||||
this->Initialize(column, allow_mask);
|
||||
}
|
||||
|
||||
explicit ArrayInterface(StringView str, bool allow_mask = true) {
|
||||
auto jinterface = Json::Load(str);
|
||||
if (IsA<Object>(jinterface)) {
|
||||
this->Initialize(get<Object const>(jinterface), allow_mask);
|
||||
return;
|
||||
@@ -279,71 +288,114 @@ class ArrayInterface {
|
||||
}
|
||||
}
|
||||
|
||||
explicit ArrayInterface(std::map<std::string, Json> const &column,
|
||||
bool allow_mask = true) {
|
||||
this->Initialize(column, allow_mask);
|
||||
}
|
||||
|
||||
void CheckType() const {
|
||||
if (type[1] == 'f' && type[2] == '4') {
|
||||
return;
|
||||
} else if (type[1] == 'f' && type[2] == '8') {
|
||||
return;
|
||||
} else if (type[1] == 'i' && type[2] == '1') {
|
||||
return;
|
||||
} else if (type[1] == 'i' && type[2] == '2') {
|
||||
return;
|
||||
} else if (type[1] == 'i' && type[2] == '4') {
|
||||
return;
|
||||
} else if (type[1] == 'i' && type[2] == '8') {
|
||||
return;
|
||||
} else if (type[1] == 'u' && type[2] == '1') {
|
||||
return;
|
||||
} else if (type[1] == 'u' && type[2] == '2') {
|
||||
return;
|
||||
} else if (type[1] == 'u' && type[2] == '4') {
|
||||
return;
|
||||
} else if (type[1] == 'u' && type[2] == '8') {
|
||||
return;
|
||||
void AssignType(StringView typestr) {
|
||||
if (typestr[1] == 'f' && typestr[2] == '4') {
|
||||
type = kF4;
|
||||
} else if (typestr[1] == 'f' && typestr[2] == '8') {
|
||||
type = kF8;
|
||||
} else if (typestr[1] == 'i' && typestr[2] == '1') {
|
||||
type = kI1;
|
||||
} else if (typestr[1] == 'i' && typestr[2] == '2') {
|
||||
type = kI2;
|
||||
} else if (typestr[1] == 'i' && typestr[2] == '4') {
|
||||
type = kI4;
|
||||
} else if (typestr[1] == 'i' && typestr[2] == '8') {
|
||||
type = kI8;
|
||||
} else if (typestr[1] == 'u' && typestr[2] == '1') {
|
||||
type = kU1;
|
||||
} else if (typestr[1] == 'u' && typestr[2] == '2') {
|
||||
type = kU2;
|
||||
} else if (typestr[1] == 'u' && typestr[2] == '4') {
|
||||
type = kU4;
|
||||
} else if (typestr[1] == 'u' && typestr[2] == '8') {
|
||||
type = kU8;
|
||||
} else {
|
||||
LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(type);
|
||||
LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(typestr);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
XGBOOST_DEVICE float GetElement(size_t idx) const {
|
||||
/**
 * Returns a copy of this interface whose `data` pointer is advanced by `offset`
 * elements of the underlying element type. All other fields are copied unchanged.
 *
 * \param offset Number of elements (not bytes) to skip from the start of `data`.
 */
XGBOOST_DEVICE ArrayInterface SliceOffset(size_t offset) const {
  // Start from a known value so an unexpected `type` can never leak an indeterminate
  // pointer into the returned interface.
  void* p_values = nullptr;
  switch (type) {
    case kF4:
      p_values = reinterpret_cast<float*>(data) + offset;
      break;
    case kF8:
      p_values = reinterpret_cast<double*>(data) + offset;
      break;
    case kI1:
      p_values = reinterpret_cast<int8_t*>(data) + offset;
      break;
    case kI2:
      p_values = reinterpret_cast<int16_t*>(data) + offset;
      break;
    case kI4:
      p_values = reinterpret_cast<int32_t*>(data) + offset;
      break;
    case kI8:
      p_values = reinterpret_cast<int64_t*>(data) + offset;
      break;
    case kU1:
      p_values = reinterpret_cast<uint8_t*>(data) + offset;
      break;
    case kU2:
      p_values = reinterpret_cast<uint16_t*>(data) + offset;
      break;
    case kU4:
      p_values = reinterpret_cast<uint32_t*>(data) + offset;
      break;
    case kU8:
      p_values = reinterpret_cast<uint64_t*>(data) + offset;
      break;
    default:
      // Unknown type tag: fail the same way `GetElement` does for an unknown type.
      SPAN_CHECK(false);
      break;
  }
  ArrayInterface ret = *this;
  ret.data = p_values;
  return ret;
}
|
||||
|
||||
/**
 * Returns a view of row `idx`: the data pointer is advanced by `idx * num_cols`
 * elements and the row count of the result is set to 1.
 */
XGBOOST_DEVICE ArrayInterface SliceRow(size_t idx) const {
  auto row = this->SliceOffset(idx * num_cols);
  row.num_rows = 1;
  return row;
}
|
||||
|
||||
template <typename T = float>
|
||||
XGBOOST_DEVICE T GetElement(size_t idx) const {
|
||||
SPAN_CHECK(idx < num_cols * num_rows);
|
||||
if (type[1] == 'f' && type[2] == '4') {
|
||||
switch (type) {
|
||||
case kF4:
|
||||
return reinterpret_cast<float*>(data)[idx];
|
||||
} else if (type[1] == 'f' && type[2] == '8') {
|
||||
case kF8:
|
||||
return reinterpret_cast<double*>(data)[idx];
|
||||
} else if (type[1] == 'i' && type[2] == '1') {
|
||||
case kI1:
|
||||
return reinterpret_cast<int8_t*>(data)[idx];
|
||||
} else if (type[1] == 'i' && type[2] == '2') {
|
||||
case kI2:
|
||||
return reinterpret_cast<int16_t*>(data)[idx];
|
||||
} else if (type[1] == 'i' && type[2] == '4') {
|
||||
case kI4:
|
||||
return reinterpret_cast<int32_t*>(data)[idx];
|
||||
} else if (type[1] == 'i' && type[2] == '8') {
|
||||
case kI8:
|
||||
return reinterpret_cast<int64_t*>(data)[idx];
|
||||
} else if (type[1] == 'u' && type[2] == '1') {
|
||||
case kU1:
|
||||
return reinterpret_cast<uint8_t*>(data)[idx];
|
||||
} else if (type[1] == 'u' && type[2] == '2') {
|
||||
case kU2:
|
||||
return reinterpret_cast<uint16_t*>(data)[idx];
|
||||
} else if (type[1] == 'u' && type[2] == '4') {
|
||||
case kU4:
|
||||
return reinterpret_cast<uint32_t*>(data)[idx];
|
||||
} else if (type[1] == 'u' && type[2] == '8') {
|
||||
case kU8:
|
||||
return reinterpret_cast<uint64_t*>(data)[idx];
|
||||
} else {
|
||||
SPAN_CHECK(false);
|
||||
return 0;
|
||||
}
|
||||
SPAN_CHECK(false);
|
||||
return reinterpret_cast<float*>(data)[idx];
|
||||
}
|
||||
|
||||
RBitField8 valid;
|
||||
bst_row_t num_rows;
|
||||
bst_feature_t num_cols;
|
||||
void* data;
|
||||
char type[3];
|
||||
|
||||
Type type;
|
||||
};
|
||||
|
||||
} // namespace xgboost
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright 2019 by XGBoost Contributors
|
||||
* Copyright 2019-2021 by XGBoost Contributors
|
||||
*
|
||||
* \file data.cu
|
||||
* \brief Handles setting metainfo from array interface.
|
||||
@@ -45,15 +45,15 @@ auto SetDeviceToPtr(void *ptr) {
|
||||
} // anonymous namespace
|
||||
|
||||
void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
|
||||
CHECK(column.type[1] == 'i' || column.type[1] == 'u')
|
||||
<< "Expected integer metainfo";
|
||||
CHECK(column.type != ArrayInterface::kF4 && column.type != ArrayInterface::kF8)
|
||||
<< "Expected integer for group info.";
|
||||
|
||||
auto ptr_device = SetDeviceToPtr(column.data);
|
||||
dh::TemporaryArray<bst_group_t> temp(column.num_rows);
|
||||
auto d_tmp = temp.data();
|
||||
|
||||
dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) {
|
||||
d_tmp[idx] = column.GetElement(idx);
|
||||
d_tmp[idx] = column.GetElement<size_t>(idx);
|
||||
});
|
||||
auto length = column.num_rows;
|
||||
out->resize(length + 1);
|
||||
@@ -103,15 +103,15 @@ void MetaInfo::SetInfo(const char * c_key, std::string const& interface_str) {
|
||||
auto it = dh::MakeTransformIterator<uint32_t>(
|
||||
thrust::make_counting_iterator(0ul),
|
||||
[array_interface] __device__(size_t i) {
|
||||
return static_cast<uint32_t>(array_interface.GetElement(i));
|
||||
return array_interface.GetElement<uint32_t>(i);
|
||||
});
|
||||
dh::caching_device_vector<bool> flag(1);
|
||||
auto d_flag = dh::ToSpan(flag);
|
||||
auto d = SetDeviceToPtr(array_interface.data);
|
||||
dh::LaunchN(d, 1, [=] __device__(size_t) { d_flag[0] = true; });
|
||||
dh::LaunchN(d, array_interface.num_rows - 1, [=] __device__(size_t i) {
|
||||
if (static_cast<uint32_t>(array_interface.GetElement(i)) >
|
||||
static_cast<uint32_t>(array_interface.GetElement(i + 1))) {
|
||||
if (array_interface.GetElement<uint32_t>(i) >
|
||||
array_interface.GetElement<uint32_t>(i + 1)) {
|
||||
d_flag[0] = false;
|
||||
}
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user