Enhance inplace prediction. (#6653)

* Accept array interface for csr and array.
* Accept an optional proxy dmatrix for metainfo.

This constructs an explicit `_ProxyDMatrix` type in Python.

* Remove unused doc.
* Add strict output.
This commit is contained in:
Jiaming Yuan
2021-02-02 11:41:46 +08:00
committed by GitHub
parent 87ab1ad607
commit 411592a347
22 changed files with 955 additions and 530 deletions

View File

@@ -1,5 +1,5 @@
/*!
* Copyright 2019 by Contributors
* Copyright 2019-2021 by Contributors
* \file array_interface.h
* \brief View of __array_interface__
*/
@@ -87,7 +87,7 @@ struct ArrayInterfaceErrors {
}
}
static std::string UnSupportedType(const char (&typestr)[3]) {
static std::string UnSupportedType(StringView typestr) {
return TypeStr(typestr[1]) + " is not supported.";
}
};
@@ -210,6 +210,7 @@ class ArrayInterfaceHandler {
static_cast<bst_feature_t>(get<Integer const>(j_shape.at(1)))};
}
}
template <typename T>
static common::Span<T> ExtractData(std::map<std::string, Json> const& column) {
Validate(column);
@@ -257,16 +258,24 @@ class ArrayInterface {
}
auto typestr = get<String const>(column.at("typestr"));
type[0] = typestr.at(0);
type[1] = typestr.at(1);
type[2] = typestr.at(2);
this->CheckType();
this->AssignType(StringView{typestr});
}
public:
enum Type : std::int8_t { kF4, kF8, kI1, kI2, kI4, kI8, kU1, kU2, kU4, kU8 };
public:
ArrayInterface() = default;
explicit ArrayInterface(std::string const& str, bool allow_mask = true) {
auto jinterface = Json::Load({str.c_str(), str.size()});
explicit ArrayInterface(std::string const &str, bool allow_mask = true)
: ArrayInterface{StringView{str.c_str(), str.size()}, allow_mask} {}
explicit ArrayInterface(std::map<std::string, Json> const &column,
bool allow_mask = true) {
this->Initialize(column, allow_mask);
}
explicit ArrayInterface(StringView str, bool allow_mask = true) {
auto jinterface = Json::Load(str);
if (IsA<Object>(jinterface)) {
this->Initialize(get<Object const>(jinterface), allow_mask);
return;
@@ -279,71 +288,114 @@ class ArrayInterface {
}
}
explicit ArrayInterface(std::map<std::string, Json> const &column,
bool allow_mask = true) {
this->Initialize(column, allow_mask);
}
void CheckType() const {
if (type[1] == 'f' && type[2] == '4') {
return;
} else if (type[1] == 'f' && type[2] == '8') {
return;
} else if (type[1] == 'i' && type[2] == '1') {
return;
} else if (type[1] == 'i' && type[2] == '2') {
return;
} else if (type[1] == 'i' && type[2] == '4') {
return;
} else if (type[1] == 'i' && type[2] == '8') {
return;
} else if (type[1] == 'u' && type[2] == '1') {
return;
} else if (type[1] == 'u' && type[2] == '2') {
return;
} else if (type[1] == 'u' && type[2] == '4') {
return;
} else if (type[1] == 'u' && type[2] == '8') {
return;
void AssignType(StringView typestr) {
if (typestr[1] == 'f' && typestr[2] == '4') {
type = kF4;
} else if (typestr[1] == 'f' && typestr[2] == '8') {
type = kF8;
} else if (typestr[1] == 'i' && typestr[2] == '1') {
type = kI1;
} else if (typestr[1] == 'i' && typestr[2] == '2') {
type = kI2;
} else if (typestr[1] == 'i' && typestr[2] == '4') {
type = kI4;
} else if (typestr[1] == 'i' && typestr[2] == '8') {
type = kI8;
} else if (typestr[1] == 'u' && typestr[2] == '1') {
type = kU1;
} else if (typestr[1] == 'u' && typestr[2] == '2') {
type = kU2;
} else if (typestr[1] == 'u' && typestr[2] == '4') {
type = kU4;
} else if (typestr[1] == 'u' && typestr[2] == '8') {
type = kU8;
} else {
LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(type);
LOG(FATAL) << ArrayInterfaceErrors::UnSupportedType(typestr);
return;
}
}
XGBOOST_DEVICE float GetElement(size_t idx) const {
XGBOOST_DEVICE ArrayInterface SliceOffset(size_t offset) const {
void* p_values;
switch (type) {
case kF4:
p_values = reinterpret_cast<float *>(data) + offset;
break;
case kF8:
p_values = reinterpret_cast<double *>(data) + offset;
break;
case kI1:
p_values = reinterpret_cast<int8_t *>(data) + offset;
break;
case kI2:
p_values = reinterpret_cast<int16_t *>(data) + offset;
break;
case kI4:
p_values = reinterpret_cast<int32_t *>(data) + offset;
break;
case kI8:
p_values = reinterpret_cast<int64_t *>(data) + offset;
break;
case kU1:
p_values = reinterpret_cast<uint8_t *>(data) + offset;
break;
case kU2:
p_values = reinterpret_cast<uint16_t *>(data) + offset;
break;
case kU4:
p_values = reinterpret_cast<uint32_t *>(data) + offset;
break;
case kU8:
p_values = reinterpret_cast<uint64_t *>(data) + offset;
break;
}
ArrayInterface ret = *this;
ret.data = p_values;
return ret;
}
XGBOOST_DEVICE ArrayInterface SliceRow(size_t idx) const {
size_t offset = idx * num_cols;
auto ret = this->SliceOffset(offset);
ret.num_rows = 1;
return ret;
}
template <typename T = float>
XGBOOST_DEVICE T GetElement(size_t idx) const {
SPAN_CHECK(idx < num_cols * num_rows);
if (type[1] == 'f' && type[2] == '4') {
switch (type) {
case kF4:
return reinterpret_cast<float*>(data)[idx];
} else if (type[1] == 'f' && type[2] == '8') {
case kF8:
return reinterpret_cast<double*>(data)[idx];
} else if (type[1] == 'i' && type[2] == '1') {
case kI1:
return reinterpret_cast<int8_t*>(data)[idx];
} else if (type[1] == 'i' && type[2] == '2') {
case kI2:
return reinterpret_cast<int16_t*>(data)[idx];
} else if (type[1] == 'i' && type[2] == '4') {
case kI4:
return reinterpret_cast<int32_t*>(data)[idx];
} else if (type[1] == 'i' && type[2] == '8') {
case kI8:
return reinterpret_cast<int64_t*>(data)[idx];
} else if (type[1] == 'u' && type[2] == '1') {
case kU1:
return reinterpret_cast<uint8_t*>(data)[idx];
} else if (type[1] == 'u' && type[2] == '2') {
case kU2:
return reinterpret_cast<uint16_t*>(data)[idx];
} else if (type[1] == 'u' && type[2] == '4') {
case kU4:
return reinterpret_cast<uint32_t*>(data)[idx];
} else if (type[1] == 'u' && type[2] == '8') {
case kU8:
return reinterpret_cast<uint64_t*>(data)[idx];
} else {
SPAN_CHECK(false);
return 0;
}
SPAN_CHECK(false);
return reinterpret_cast<float*>(data)[idx];
}
RBitField8 valid;
bst_row_t num_rows;
bst_feature_t num_cols;
void* data;
char type[3];
Type type;
};
} // namespace xgboost