Fix inplace predict missing value. (#6787)
This commit is contained in:
parent
5c87c2bba8
commit
a59c7323b4
@ -255,7 +255,7 @@ XGB_DLL int XGDMatrixCreateFromCSR(char const *indptr,
|
|||||||
data::CSRArrayAdapter adapter(StringView{indptr}, StringView{indices},
|
data::CSRArrayAdapter adapter(StringView{indptr}, StringView{indices},
|
||||||
StringView{data}, ncol);
|
StringView{data}, ncol);
|
||||||
auto config = Json::Load(StringView{c_json_config});
|
auto config = Json::Load(StringView{c_json_config});
|
||||||
float missing = get<Number const>(config["missing"]);
|
float missing = GetMissing(config);
|
||||||
auto nthread = get<Integer const>(config["nthread"]);
|
auto nthread = get<Integer const>(config["nthread"]);
|
||||||
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread));
|
*out = new std::shared_ptr<DMatrix>(DMatrix::Create(&adapter, missing, nthread));
|
||||||
API_END();
|
API_END();
|
||||||
@ -683,8 +683,8 @@ void InplacePredictImpl(std::shared_ptr<T> x, std::shared_ptr<DMatrix> p_m,
|
|||||||
|
|
||||||
HostDeviceVector<float>* p_predt { nullptr };
|
HostDeviceVector<float>* p_predt { nullptr };
|
||||||
auto type = PredictionType(get<Integer const>(config["type"]));
|
auto type = PredictionType(get<Integer const>(config["type"]));
|
||||||
learner->InplacePredict(x, p_m, type, get<Number const>(config["missing"]),
|
float missing = GetMissing(config);
|
||||||
&p_predt,
|
learner->InplacePredict(x, p_m, type, missing, &p_predt,
|
||||||
get<Integer const>(config["iteration_begin"]),
|
get<Integer const>(config["iteration_begin"]),
|
||||||
get<Integer const>(config["iteration_end"]));
|
get<Integer const>(config["iteration_end"]));
|
||||||
CHECK(p_predt);
|
CHECK(p_predt);
|
||||||
|
|||||||
@ -48,8 +48,9 @@ int InplacePreidctCuda(BoosterHandle handle, char const *c_json_strs,
|
|||||||
auto x = std::make_shared<T>(json_str);
|
auto x = std::make_shared<T>(json_str);
|
||||||
HostDeviceVector<float> *p_predt{nullptr};
|
HostDeviceVector<float> *p_predt{nullptr};
|
||||||
auto type = PredictionType(get<Integer const>(config["type"]));
|
auto type = PredictionType(get<Integer const>(config["type"]));
|
||||||
learner->InplacePredict(x, p_m, type, get<Number const>(config["missing"]),
|
float missing = GetMissing(config);
|
||||||
&p_predt,
|
|
||||||
|
learner->InplacePredict(x, p_m, type, missing, &p_predt,
|
||||||
get<Integer const>(config["iteration_begin"]),
|
get<Integer const>(config["iteration_begin"]),
|
||||||
get<Integer const>(config["iteration_end"]));
|
get<Integer const>(config["iteration_end"]));
|
||||||
CHECK(p_predt);
|
CHECK(p_predt);
|
||||||
|
|||||||
@ -11,6 +11,9 @@
|
|||||||
#include "xgboost/logging.h"
|
#include "xgboost/logging.h"
|
||||||
#include "xgboost/json.h"
|
#include "xgboost/json.h"
|
||||||
#include "xgboost/learner.h"
|
#include "xgboost/learner.h"
|
||||||
|
#include "xgboost/c_api.h"
|
||||||
|
|
||||||
|
#include "c_api_error.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
/* \brief Determine the output shape of prediction.
|
/* \brief Determine the output shape of prediction.
|
||||||
@ -141,5 +144,19 @@ inline uint32_t GetIterationFromTreeLimit(uint32_t ntree_limit, Learner *learner
|
|||||||
}
|
}
|
||||||
return ntree_limit;
|
return ntree_limit;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline float GetMissing(Json const &config) {
|
||||||
|
float missing;
|
||||||
|
auto const& j_missing = config["missing"];
|
||||||
|
if (IsA<Number const>(j_missing)) {
|
||||||
|
missing = get<Number const>(j_missing);
|
||||||
|
} else if (IsA<Integer const>(j_missing)) {
|
||||||
|
missing = get<Integer const>(j_missing);
|
||||||
|
} else {
|
||||||
|
missing = nan("");
|
||||||
|
LOG(FATAL) << "Invalid missing value: " << j_missing;
|
||||||
|
}
|
||||||
|
return missing;
|
||||||
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
#endif // XGBOOST_C_API_C_API_UTILS_H_
|
#endif // XGBOOST_C_API_C_API_UTILS_H_
|
||||||
|
|||||||
@ -16,9 +16,14 @@ namespace xgboost {
|
|||||||
namespace data {
|
namespace data {
|
||||||
|
|
||||||
struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
|
struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
|
||||||
explicit IsValidFunctor(float missing) : missing(missing) {}
|
|
||||||
|
|
||||||
float missing;
|
float missing;
|
||||||
|
|
||||||
|
XGBOOST_DEVICE explicit IsValidFunctor(float missing) : missing(missing) {}
|
||||||
|
|
||||||
|
__device__ bool operator()(float value) const {
|
||||||
|
return !(common::CheckNAN(value) || value == missing);
|
||||||
|
}
|
||||||
|
|
||||||
__device__ bool operator()(const data::COOTuple& e) const {
|
__device__ bool operator()(const data::COOTuple& e) const {
|
||||||
if (common::CheckNAN(e.value) || e.value == missing) {
|
if (common::CheckNAN(e.value) || e.value == missing) {
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@ -76,7 +76,7 @@ struct SparsePageLoader {
|
|||||||
size_t entry_start;
|
size_t entry_start;
|
||||||
|
|
||||||
__device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features,
|
__device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features,
|
||||||
bst_row_t num_rows, size_t entry_start)
|
bst_row_t num_rows, size_t entry_start, float)
|
||||||
: use_shared(use_shared),
|
: use_shared(use_shared),
|
||||||
data(data),
|
data(data),
|
||||||
entry_start(entry_start) {
|
entry_start(entry_start) {
|
||||||
@ -111,7 +111,7 @@ struct SparsePageLoader {
|
|||||||
struct EllpackLoader {
|
struct EllpackLoader {
|
||||||
EllpackDeviceAccessor const& matrix;
|
EllpackDeviceAccessor const& matrix;
|
||||||
XGBOOST_DEVICE EllpackLoader(EllpackDeviceAccessor const& m, bool,
|
XGBOOST_DEVICE EllpackLoader(EllpackDeviceAccessor const& m, bool,
|
||||||
bst_feature_t, bst_row_t, size_t)
|
bst_feature_t, bst_row_t, size_t, float)
|
||||||
: matrix{m} {}
|
: matrix{m} {}
|
||||||
__device__ __forceinline__ float GetElement(size_t ridx, size_t fidx) const {
|
__device__ __forceinline__ float GetElement(size_t ridx, size_t fidx) const {
|
||||||
auto gidx = matrix.GetBinIndex(ridx, fidx);
|
auto gidx = matrix.GetBinIndex(ridx, fidx);
|
||||||
@ -133,15 +133,17 @@ struct DeviceAdapterLoader {
|
|||||||
bst_feature_t columns;
|
bst_feature_t columns;
|
||||||
float* smem;
|
float* smem;
|
||||||
bool use_shared;
|
bool use_shared;
|
||||||
|
data::IsValidFunctor is_valid;
|
||||||
|
|
||||||
using BatchT = Batch;
|
using BatchT = Batch;
|
||||||
|
|
||||||
XGBOOST_DEV_INLINE DeviceAdapterLoader(Batch const batch, bool use_shared,
|
XGBOOST_DEV_INLINE DeviceAdapterLoader(Batch const batch, bool use_shared,
|
||||||
bst_feature_t num_features, bst_row_t num_rows,
|
bst_feature_t num_features, bst_row_t num_rows,
|
||||||
size_t entry_start) :
|
size_t entry_start, float missing) :
|
||||||
batch{batch},
|
batch{batch},
|
||||||
columns{num_features},
|
columns{num_features},
|
||||||
use_shared{use_shared} {
|
use_shared{use_shared},
|
||||||
|
is_valid{missing} {
|
||||||
extern __shared__ float _smem[];
|
extern __shared__ float _smem[];
|
||||||
smem = _smem;
|
smem = _smem;
|
||||||
if (use_shared) {
|
if (use_shared) {
|
||||||
@ -153,7 +155,10 @@ struct DeviceAdapterLoader {
|
|||||||
auto beg = global_idx * columns;
|
auto beg = global_idx * columns;
|
||||||
auto end = (global_idx + 1) * columns;
|
auto end = (global_idx + 1) * columns;
|
||||||
for (size_t i = beg; i < end; ++i) {
|
for (size_t i = beg; i < end; ++i) {
|
||||||
smem[threadIdx.x * num_features + (i - beg)] = batch.GetElement(i).value;
|
auto value = batch.GetElement(i).value;
|
||||||
|
if (is_valid(value)) {
|
||||||
|
smem[threadIdx.x * num_features + (i - beg)] = value;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -164,7 +169,12 @@ struct DeviceAdapterLoader {
|
|||||||
if (use_shared) {
|
if (use_shared) {
|
||||||
return smem[threadIdx.x * columns + fidx];
|
return smem[threadIdx.x * columns + fidx];
|
||||||
}
|
}
|
||||||
return batch.GetElement(ridx * columns + fidx).value;
|
auto value = batch.GetElement(ridx * columns + fidx).value;
|
||||||
|
if (is_valid(value)) {
|
||||||
|
return value;
|
||||||
|
} else {
|
||||||
|
return nan("");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -209,7 +219,7 @@ __device__ bst_node_t GetLeafIndex(bst_row_t ridx, const RegTree::Node* tree,
|
|||||||
while (!n.IsLeaf()) {
|
while (!n.IsLeaf()) {
|
||||||
float fvalue = loader.GetElement(ridx, n.SplitIndex());
|
float fvalue = loader.GetElement(ridx, n.SplitIndex());
|
||||||
// Missing value
|
// Missing value
|
||||||
if (isnan(fvalue)) {
|
if (common::CheckNAN(fvalue)) {
|
||||||
nidx = n.DefaultChild();
|
nidx = n.DefaultChild();
|
||||||
n = tree[nidx];
|
n = tree[nidx];
|
||||||
} else {
|
} else {
|
||||||
@ -231,12 +241,13 @@ __global__ void PredictLeafKernel(Data data,
|
|||||||
common::Span<float> d_out_predictions,
|
common::Span<float> d_out_predictions,
|
||||||
common::Span<size_t const> d_tree_segments,
|
common::Span<size_t const> d_tree_segments,
|
||||||
size_t tree_begin, size_t tree_end, size_t num_features,
|
size_t tree_begin, size_t tree_end, size_t num_features,
|
||||||
size_t num_rows, size_t entry_start, bool use_shared) {
|
size_t num_rows, size_t entry_start, bool use_shared,
|
||||||
|
float missing) {
|
||||||
bst_row_t ridx = blockDim.x * blockIdx.x + threadIdx.x;
|
bst_row_t ridx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
if (ridx >= num_rows) {
|
if (ridx >= num_rows) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
Loader loader(data, use_shared, num_features, num_rows, entry_start);
|
Loader loader(data, use_shared, num_features, num_rows, entry_start, missing);
|
||||||
for (int tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
|
for (int tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
|
||||||
const RegTree::Node* d_tree = &d_nodes[d_tree_segments[tree_idx - tree_begin]];
|
const RegTree::Node* d_tree = &d_nodes[d_tree_segments[tree_idx - tree_begin]];
|
||||||
auto leaf = GetLeafIndex(ridx, d_tree, loader);
|
auto leaf = GetLeafIndex(ridx, d_tree, loader);
|
||||||
@ -255,9 +266,9 @@ PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
|
|||||||
common::Span<RegTree::Segment const> d_cat_node_segments,
|
common::Span<RegTree::Segment const> d_cat_node_segments,
|
||||||
common::Span<uint32_t const> d_categories, size_t tree_begin,
|
common::Span<uint32_t const> d_categories, size_t tree_begin,
|
||||||
size_t tree_end, size_t num_features, size_t num_rows,
|
size_t tree_end, size_t num_features, size_t num_rows,
|
||||||
size_t entry_start, bool use_shared, int num_group) {
|
size_t entry_start, bool use_shared, int num_group, float missing) {
|
||||||
bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||||
Loader loader(data, use_shared, num_features, num_rows, entry_start);
|
Loader loader(data, use_shared, num_features, num_rows, entry_start, missing);
|
||||||
if (global_idx >= num_rows) return;
|
if (global_idx >= num_rows) return;
|
||||||
if (num_group == 1) {
|
if (num_group == 1) {
|
||||||
float sum = 0;
|
float sum = 0;
|
||||||
@ -527,7 +538,7 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
model.categories_tree_segments.ConstDeviceSpan(),
|
model.categories_tree_segments.ConstDeviceSpan(),
|
||||||
model.categories_node_segments.ConstDeviceSpan(),
|
model.categories_node_segments.ConstDeviceSpan(),
|
||||||
model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
|
model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
|
||||||
num_features, num_rows, entry_start, use_shared, model.num_group);
|
num_features, num_rows, entry_start, use_shared, model.num_group, nan(""));
|
||||||
}
|
}
|
||||||
void PredictInternal(EllpackDeviceAccessor const& batch,
|
void PredictInternal(EllpackDeviceAccessor const& batch,
|
||||||
DeviceModel const& model,
|
DeviceModel const& model,
|
||||||
@ -549,7 +560,7 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
model.categories_node_segments.ConstDeviceSpan(),
|
model.categories_node_segments.ConstDeviceSpan(),
|
||||||
model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
|
model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
|
||||||
batch.NumFeatures(), num_rows, entry_start, use_shared,
|
batch.NumFeatures(), num_rows, entry_start, use_shared,
|
||||||
model.num_group);
|
model.num_group, nan(""));
|
||||||
}
|
}
|
||||||
|
|
||||||
void DevicePredictInternal(DMatrix* dmat, HostDeviceVector<float>* out_preds,
|
void DevicePredictInternal(DMatrix* dmat, HostDeviceVector<float>* out_preds,
|
||||||
@ -607,7 +618,7 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
|
|
||||||
template <typename Adapter, typename Loader>
|
template <typename Adapter, typename Loader>
|
||||||
void DispatchedInplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
void DispatchedInplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||||
const gbm::GBTreeModel &model, float,
|
const gbm::GBTreeModel &model, float missing,
|
||||||
PredictionCacheEntry *out_preds,
|
PredictionCacheEntry *out_preds,
|
||||||
uint32_t tree_begin, uint32_t tree_end) const {
|
uint32_t tree_begin, uint32_t tree_end) const {
|
||||||
uint32_t const output_groups = model.learner_model_param->num_output_group;
|
uint32_t const output_groups = model.learner_model_param->num_output_group;
|
||||||
@ -648,7 +659,7 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
d_model.categories_tree_segments.ConstDeviceSpan(),
|
d_model.categories_tree_segments.ConstDeviceSpan(),
|
||||||
d_model.categories_node_segments.ConstDeviceSpan(),
|
d_model.categories_node_segments.ConstDeviceSpan(),
|
||||||
d_model.categories.ConstDeviceSpan(), tree_begin, tree_end, m->NumColumns(),
|
d_model.categories.ConstDeviceSpan(), tree_begin, tree_end, m->NumColumns(),
|
||||||
m->NumRows(), entry_start, use_shared, output_groups);
|
m->NumRows(), entry_start, use_shared, output_groups, missing);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
|
||||||
@ -836,7 +847,7 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
predictions->DeviceSpan().subspan(batch_offset),
|
predictions->DeviceSpan().subspan(batch_offset),
|
||||||
d_model.tree_segments.ConstDeviceSpan(),
|
d_model.tree_segments.ConstDeviceSpan(),
|
||||||
d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
|
d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
|
||||||
entry_start, use_shared);
|
entry_start, use_shared, nan(""));
|
||||||
batch_offset += batch.Size();
|
batch_offset += batch.Size();
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -852,7 +863,7 @@ class GPUPredictor : public xgboost::Predictor {
|
|||||||
predictions->DeviceSpan().subspan(batch_offset),
|
predictions->DeviceSpan().subspan(batch_offset),
|
||||||
d_model.tree_segments.ConstDeviceSpan(),
|
d_model.tree_segments.ConstDeviceSpan(),
|
||||||
d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
|
d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
|
||||||
entry_start, use_shared);
|
entry_start, use_shared, nan(""));
|
||||||
batch_offset += batch.Size();
|
batch_offset += batch.Size();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -374,6 +374,11 @@ TEST(Json, AssigningNumber) {
|
|||||||
value = 15; // NOLINT
|
value = 15; // NOLINT
|
||||||
ASSERT_EQ(get<Number>(json), 4);
|
ASSERT_EQ(get<Number>(json), 4);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
Json value {Number(std::numeric_limits<float>::quiet_NaN())};
|
||||||
|
ASSERT_TRUE(IsA<Number>(value));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(Json, AssigningString) {
|
TEST(Json, AssigningString) {
|
||||||
|
|||||||
@ -154,17 +154,22 @@ class TestGPUPredict:
|
|||||||
cp.cuda.runtime.setDevice(0)
|
cp.cuda.runtime.setDevice(0)
|
||||||
rows = 1000
|
rows = 1000
|
||||||
cols = 10
|
cols = 10
|
||||||
|
missing = 11 # set to integer for testing
|
||||||
|
|
||||||
cp_rng = cp.random.RandomState(1994)
|
cp_rng = cp.random.RandomState(1994)
|
||||||
cp.random.set_random_state(cp_rng)
|
cp.random.set_random_state(cp_rng)
|
||||||
|
|
||||||
X = cp.random.randn(rows, cols)
|
X = cp.random.randn(rows, cols)
|
||||||
|
missing_idx = [i for i in range(0, cols, 4)]
|
||||||
|
X[:, missing_idx] = missing # set to be missing
|
||||||
y = cp.random.randn(rows)
|
y = cp.random.randn(rows)
|
||||||
|
|
||||||
dtrain = xgb.DMatrix(X, y)
|
dtrain = xgb.DMatrix(X, y)
|
||||||
|
|
||||||
booster = xgb.train({'tree_method': 'gpu_hist'},
|
booster = xgb.train({'tree_method': 'gpu_hist'}, dtrain, num_boost_round=10)
|
||||||
dtrain, num_boost_round=10)
|
|
||||||
test = xgb.DMatrix(X[:10, ...])
|
test = xgb.DMatrix(X[:10, ...], missing=missing)
|
||||||
predt_from_array = booster.inplace_predict(X[:10, ...])
|
predt_from_array = booster.inplace_predict(X[:10, ...], missing=missing)
|
||||||
predt_from_dmatrix = booster.predict(test)
|
predt_from_dmatrix = booster.predict(test)
|
||||||
|
|
||||||
cp.testing.assert_allclose(predt_from_array, predt_from_dmatrix)
|
cp.testing.assert_allclose(predt_from_array, predt_from_dmatrix)
|
||||||
@ -185,6 +190,20 @@ class TestGPUPredict:
|
|||||||
base_margin = cp_rng.randn(rows)
|
base_margin = cp_rng.randn(rows)
|
||||||
self.run_inplace_base_margin(booster, dtrain, X, base_margin)
|
self.run_inplace_base_margin(booster, dtrain, X, base_margin)
|
||||||
|
|
||||||
|
# Create a wide dataset
|
||||||
|
X = cp_rng.randn(100, 10000)
|
||||||
|
y = cp_rng.randn(100)
|
||||||
|
|
||||||
|
missing_idx = [i for i in range(0, X.shape[1], 16)]
|
||||||
|
X[:, missing_idx] = missing
|
||||||
|
reg = xgb.XGBRegressor(tree_method="gpu_hist", n_estimators=8, missing=missing)
|
||||||
|
reg.fit(X, y)
|
||||||
|
|
||||||
|
gpu_predt = reg.predict(X)
|
||||||
|
reg.set_params(predictor="cpu_predictor")
|
||||||
|
cpu_predt = reg.predict(X)
|
||||||
|
np.testing.assert_allclose(gpu_predt, cpu_predt, atol=1e-6)
|
||||||
|
|
||||||
@pytest.mark.skipif(**tm.no_cudf())
|
@pytest.mark.skipif(**tm.no_cudf())
|
||||||
def test_inplace_predict_cudf(self):
|
def test_inplace_predict_cudf(self):
|
||||||
import cupy as cp
|
import cupy as cp
|
||||||
|
|||||||
@ -103,31 +103,37 @@ class TestInplacePredict:
|
|||||||
'''Tests for running inplace prediction'''
|
'''Tests for running inplace prediction'''
|
||||||
@classmethod
|
@classmethod
|
||||||
def setup_class(cls):
|
def setup_class(cls):
|
||||||
cls.rows = 100
|
cls.rows = 1000
|
||||||
cls.cols = 10
|
cls.cols = 10
|
||||||
|
|
||||||
|
cls.missing = 11 # set to integer for testing
|
||||||
|
|
||||||
cls.rng = np.random.RandomState(1994)
|
cls.rng = np.random.RandomState(1994)
|
||||||
|
|
||||||
cls.X = cls.rng.randn(cls.rows, cls.cols)
|
cls.X = cls.rng.randn(cls.rows, cls.cols)
|
||||||
|
missing_idx = [i for i in range(0, cls.cols, 4)]
|
||||||
|
cls.X[:, missing_idx] = cls.missing # set to be missing
|
||||||
|
|
||||||
cls.y = cls.rng.randn(cls.rows)
|
cls.y = cls.rng.randn(cls.rows)
|
||||||
|
|
||||||
dtrain = xgb.DMatrix(cls.X, cls.y)
|
dtrain = xgb.DMatrix(cls.X, cls.y)
|
||||||
|
cls.test = xgb.DMatrix(cls.X[:10, ...], missing=cls.missing)
|
||||||
|
|
||||||
cls.booster = xgb.train({'tree_method': 'hist'}, dtrain, num_boost_round=10)
|
cls.booster = xgb.train({'tree_method': 'hist'}, dtrain, num_boost_round=10)
|
||||||
|
|
||||||
cls.test = xgb.DMatrix(cls.X[:10, ...])
|
|
||||||
|
|
||||||
def test_predict(self):
|
def test_predict(self):
|
||||||
booster = self.booster
|
booster = self.booster
|
||||||
X = self.X
|
X = self.X
|
||||||
test = self.test
|
test = self.test
|
||||||
|
|
||||||
predt_from_array = booster.inplace_predict(X[:10, ...])
|
predt_from_array = booster.inplace_predict(X[:10, ...], missing=self.missing)
|
||||||
predt_from_dmatrix = booster.predict(test)
|
predt_from_dmatrix = booster.predict(test)
|
||||||
|
|
||||||
np.testing.assert_allclose(predt_from_dmatrix, predt_from_array)
|
np.testing.assert_allclose(predt_from_dmatrix, predt_from_array)
|
||||||
|
|
||||||
predt_from_array = booster.inplace_predict(X[:10, ...], iteration_range=(0, 4))
|
predt_from_array = booster.inplace_predict(
|
||||||
|
X[:10, ...], iteration_range=(0, 4), missing=self.missing
|
||||||
|
)
|
||||||
predt_from_dmatrix = booster.predict(test, ntree_limit=4)
|
predt_from_dmatrix = booster.predict(test, ntree_limit=4)
|
||||||
|
|
||||||
np.testing.assert_allclose(predt_from_dmatrix, predt_from_array)
|
np.testing.assert_allclose(predt_from_dmatrix, predt_from_array)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user