SYCL inference optimization (#9876)
--------- Co-authored-by: Dmitry Razdoburdin <>
This commit is contained in:
parent
1c6e031c75
commit
2a6ab2547d
@ -66,13 +66,13 @@ class USMVector {
|
|||||||
public:
|
public:
|
||||||
USMVector() : size_(0), capacity_(0), data_(nullptr) {}
|
USMVector() : size_(0), capacity_(0), data_(nullptr) {}
|
||||||
|
|
||||||
USMVector(::sycl::queue& qu, size_t size) : size_(size), capacity_(size) {
|
USMVector(::sycl::queue* qu, size_t size) : size_(size), capacity_(size) {
|
||||||
data_ = allocate_memory_(qu, size_);
|
data_ = allocate_memory_(qu, size_);
|
||||||
}
|
}
|
||||||
|
|
||||||
USMVector(::sycl::queue& qu, size_t size, T v) : size_(size), capacity_(size) {
|
USMVector(::sycl::queue* qu, size_t size, T v) : size_(size), capacity_(size) {
|
||||||
data_ = allocate_memory_(qu, size_);
|
data_ = allocate_memory_(qu, size_);
|
||||||
qu.fill(data_.get(), v, size_).wait();
|
qu->fill(data_.get(), v, size_).wait();
|
||||||
}
|
}
|
||||||
|
|
||||||
USMVector(::sycl::queue* qu, const std::vector<T> &vec) {
|
USMVector(::sycl::queue* qu, const std::vector<T> &vec) {
|
||||||
@ -147,25 +147,22 @@ class USMVector {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
::sycl::event ResizeAsync(::sycl::queue* qu, size_t size_new, T v) {
|
void Resize(::sycl::queue* qu, size_t size_new, T v, ::sycl::event* event) {
|
||||||
if (size_new <= size_) {
|
if (size_new <= size_) {
|
||||||
size_ = size_new;
|
size_ = size_new;
|
||||||
return ::sycl::event();
|
|
||||||
} else if (size_new <= capacity_) {
|
} else if (size_new <= capacity_) {
|
||||||
auto event = qu->fill(data_.get() + size_, v, size_new - size_);
|
auto event = qu->fill(data_.get() + size_, v, size_new - size_);
|
||||||
size_ = size_new;
|
size_ = size_new;
|
||||||
return event;
|
|
||||||
} else {
|
} else {
|
||||||
size_t size_old = size_;
|
size_t size_old = size_;
|
||||||
auto data_old = data_;
|
auto data_old = data_;
|
||||||
size_ = size_new;
|
size_ = size_new;
|
||||||
capacity_ = size_new;
|
capacity_ = size_new;
|
||||||
data_ = allocate_memory_(qu, size_);
|
data_ = allocate_memory_(qu, size_);
|
||||||
::sycl::event event;
|
|
||||||
if (size_old > 0) {
|
if (size_old > 0) {
|
||||||
event = qu->memcpy(data_.get(), data_old.get(), sizeof(T) * size_old);
|
*event = qu->memcpy(data_.get(), data_old.get(), sizeof(T) * size_old, *event);
|
||||||
}
|
}
|
||||||
return qu->fill(data_.get() + size_old, v, size_new - size_old, event);
|
*event = qu->fill(data_.get() + size_old, v, size_new - size_old, *event);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -210,7 +207,7 @@ struct DeviceMatrix {
|
|||||||
DMatrix* p_mat; // Pointer to the original matrix on the host
|
DMatrix* p_mat; // Pointer to the original matrix on the host
|
||||||
::sycl::queue qu_;
|
::sycl::queue qu_;
|
||||||
USMVector<size_t> row_ptr;
|
USMVector<size_t> row_ptr;
|
||||||
USMVector<Entry> data;
|
USMVector<Entry, MemoryType::on_device> data;
|
||||||
size_t total_offset;
|
size_t total_offset;
|
||||||
|
|
||||||
DeviceMatrix(::sycl::queue qu, DMatrix* dmat) : p_mat(dmat), qu_(qu) {
|
DeviceMatrix(::sycl::queue qu, DMatrix* dmat) : p_mat(dmat), qu_(qu) {
|
||||||
@ -238,8 +235,9 @@ struct DeviceMatrix {
|
|||||||
for (size_t i = 0; i < batch_size; i++)
|
for (size_t i = 0; i < batch_size; i++)
|
||||||
row_ptr[i + batch.base_rowid] += batch.base_rowid;
|
row_ptr[i + batch.base_rowid] += batch.base_rowid;
|
||||||
}
|
}
|
||||||
std::copy(data_vec.data(), data_vec.data() + offset_vec[batch_size],
|
qu.memcpy(data.Data() + data_offset,
|
||||||
data.Data() + data_offset);
|
data_vec.data(),
|
||||||
|
offset_vec[batch_size] * sizeof(Entry)).wait();
|
||||||
data_offset += offset_vec[batch_size];
|
data_offset += offset_vec[batch_size];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -20,6 +20,7 @@
|
|||||||
#include "xgboost/tree_model.h"
|
#include "xgboost/tree_model.h"
|
||||||
#include "xgboost/predictor.h"
|
#include "xgboost/predictor.h"
|
||||||
#include "xgboost/tree_updater.h"
|
#include "xgboost/tree_updater.h"
|
||||||
|
#include "../../../src/common/timer.h"
|
||||||
|
|
||||||
#pragma GCC diagnostic push
|
#pragma GCC diagnostic push
|
||||||
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
|
||||||
@ -36,36 +37,37 @@ namespace predictor {
|
|||||||
|
|
||||||
DMLC_REGISTRY_FILE_TAG(predictor_sycl);
|
DMLC_REGISTRY_FILE_TAG(predictor_sycl);
|
||||||
|
|
||||||
/* Wrapper for descriptor of a tree node */
|
union NodeValue {
|
||||||
struct DeviceNode {
|
float leaf_weight;
|
||||||
DeviceNode()
|
float fvalue;
|
||||||
: fidx(-1), left_child_idx(-1), right_child_idx(-1) {}
|
};
|
||||||
|
|
||||||
union NodeValue {
|
|
||||||
float leaf_weight;
|
|
||||||
float fvalue;
|
|
||||||
};
|
|
||||||
|
|
||||||
|
class Node {
|
||||||
int fidx;
|
int fidx;
|
||||||
int left_child_idx;
|
int left_child_idx;
|
||||||
int right_child_idx;
|
int right_child_idx;
|
||||||
NodeValue val;
|
NodeValue val;
|
||||||
|
|
||||||
explicit DeviceNode(const RegTree::Node& n) {
|
public:
|
||||||
this->left_child_idx = n.LeftChild();
|
explicit Node(const RegTree::Node& n) {
|
||||||
this->right_child_idx = n.RightChild();
|
left_child_idx = n.LeftChild();
|
||||||
this->fidx = n.SplitIndex();
|
right_child_idx = n.RightChild();
|
||||||
|
fidx = n.SplitIndex();
|
||||||
if (n.DefaultLeft()) {
|
if (n.DefaultLeft()) {
|
||||||
fidx |= (1U << 31);
|
fidx |= (1U << 31);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (n.IsLeaf()) {
|
if (n.IsLeaf()) {
|
||||||
this->val.leaf_weight = n.LeafValue();
|
val.leaf_weight = n.LeafValue();
|
||||||
} else {
|
} else {
|
||||||
this->val.fvalue = n.SplitCond();
|
val.fvalue = n.SplitCond();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int LeftChildIdx() const {return left_child_idx; }
|
||||||
|
|
||||||
|
int RightChildIdx() const {return right_child_idx; }
|
||||||
|
|
||||||
bool IsLeaf() const { return left_child_idx == -1; }
|
bool IsLeaf() const { return left_child_idx == -1; }
|
||||||
|
|
||||||
int GetFidx() const { return fidx & ((1U << 31) - 1U); }
|
int GetFidx() const { return fidx & ((1U << 31) - 1U); }
|
||||||
@ -74,9 +76,9 @@ struct DeviceNode {
|
|||||||
|
|
||||||
int MissingIdx() const {
|
int MissingIdx() const {
|
||||||
if (MissingLeft()) {
|
if (MissingLeft()) {
|
||||||
return this->left_child_idx;
|
return left_child_idx;
|
||||||
} else {
|
} else {
|
||||||
return this->right_child_idx;
|
return right_child_idx;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -85,105 +87,79 @@ struct DeviceNode {
|
|||||||
float GetWeight() const { return val.leaf_weight; }
|
float GetWeight() const { return val.leaf_weight; }
|
||||||
};
|
};
|
||||||
|
|
||||||
/* SYCL implementation of a device model,
|
|
||||||
* storing tree structure in USM buffers to provide access from device kernels
|
|
||||||
*/
|
|
||||||
class DeviceModel {
|
class DeviceModel {
|
||||||
public:
|
public:
|
||||||
::sycl::queue qu_;
|
USMVector<Node> nodes;
|
||||||
USMVector<DeviceNode> nodes_;
|
USMVector<size_t> first_node_position;
|
||||||
USMVector<size_t> tree_segments_;
|
USMVector<int> tree_group;
|
||||||
USMVector<int> tree_group_;
|
size_t tree_beg;
|
||||||
size_t tree_beg_;
|
size_t tree_end;
|
||||||
size_t tree_end_;
|
int num_group;
|
||||||
int num_group_;
|
|
||||||
|
|
||||||
DeviceModel() {}
|
void Init(::sycl::queue* qu, const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) {
|
||||||
|
int n_nodes = 0;
|
||||||
~DeviceModel() {}
|
first_node_position.Resize(qu, (tree_end - tree_begin) + 1);
|
||||||
|
first_node_position[0] = n_nodes;
|
||||||
void Init(::sycl::queue qu, const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end) {
|
|
||||||
qu_ = qu;
|
|
||||||
|
|
||||||
tree_segments_.Resize(&qu_, (tree_end - tree_begin) + 1);
|
|
||||||
int sum = 0;
|
|
||||||
tree_segments_[0] = sum;
|
|
||||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||||
if (model.trees[tree_idx]->HasCategoricalSplit()) {
|
if (model.trees[tree_idx]->HasCategoricalSplit()) {
|
||||||
LOG(FATAL) << "Categorical features are not yet supported by sycl";
|
LOG(FATAL) << "Categorical features are not yet supported by sycl";
|
||||||
}
|
}
|
||||||
sum += model.trees[tree_idx]->GetNodes().size();
|
n_nodes += model.trees[tree_idx]->GetNodes().size();
|
||||||
tree_segments_[tree_idx - tree_begin + 1] = sum;
|
first_node_position[tree_idx - tree_begin + 1] = n_nodes;
|
||||||
}
|
}
|
||||||
|
|
||||||
nodes_.Resize(&qu_, sum);
|
nodes.Resize(qu, n_nodes);
|
||||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||||
auto& src_nodes = model.trees[tree_idx]->GetNodes();
|
auto& src_nodes = model.trees[tree_idx]->GetNodes();
|
||||||
for (size_t node_idx = 0; node_idx < src_nodes.size(); node_idx++)
|
size_t n_nodes_shift = first_node_position[tree_idx - tree_begin];
|
||||||
nodes_[node_idx + tree_segments_[tree_idx - tree_begin]] =
|
for (size_t node_idx = 0; node_idx < src_nodes.size(); node_idx++) {
|
||||||
static_cast<DeviceNode>(src_nodes[node_idx]);
|
nodes[node_idx + n_nodes_shift] = static_cast<Node>(src_nodes[node_idx]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tree_group_.Resize(&qu_, model.tree_info.size());
|
tree_group.Resize(qu, model.tree_info.size());
|
||||||
for (size_t tree_idx = 0; tree_idx < model.tree_info.size(); tree_idx++)
|
for (size_t tree_idx = 0; tree_idx < model.tree_info.size(); tree_idx++)
|
||||||
tree_group_[tree_idx] = model.tree_info[tree_idx];
|
tree_group[tree_idx] = model.tree_info[tree_idx];
|
||||||
|
|
||||||
tree_beg_ = tree_begin;
|
tree_beg = tree_begin;
|
||||||
tree_end_ = tree_end;
|
tree_end = tree_end;
|
||||||
num_group_ = model.learner_model_param->num_output_group;
|
num_group = model.learner_model_param->num_output_group;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
float GetFvalue(int ridx, int fidx, Entry* data, size_t* row_ptr, bool* is_missing) {
|
float GetLeafWeight(const Node* nodes, const float* fval_buff, const uint8_t* miss_buff) {
|
||||||
// Binary search
|
const Node* node = nodes;
|
||||||
auto begin_ptr = data + row_ptr[ridx];
|
while (!node->IsLeaf()) {
|
||||||
auto end_ptr = data + row_ptr[ridx + 1];
|
if (miss_buff[node->GetFidx()] == 1) {
|
||||||
Entry* previous_middle = nullptr;
|
node = nodes + node->MissingIdx();
|
||||||
while (end_ptr != begin_ptr) {
|
|
||||||
auto middle = begin_ptr + (end_ptr - begin_ptr) / 2;
|
|
||||||
if (middle == previous_middle) {
|
|
||||||
break;
|
|
||||||
} else {
|
} else {
|
||||||
previous_middle = middle;
|
const float fvalue = fval_buff[node->GetFidx()];
|
||||||
}
|
if (fvalue < node->GetFvalue()) {
|
||||||
|
node = nodes + node->LeftChildIdx();
|
||||||
if (middle->index == fidx) {
|
|
||||||
*is_missing = false;
|
|
||||||
return middle->fvalue;
|
|
||||||
} else if (middle->index < fidx) {
|
|
||||||
begin_ptr = middle;
|
|
||||||
} else {
|
|
||||||
end_ptr = middle;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*is_missing = true;
|
|
||||||
return 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
float GetLeafWeight(int ridx, const DeviceNode* tree, Entry* data, size_t* row_ptr) {
|
|
||||||
DeviceNode n = tree[0];
|
|
||||||
int node_id = 0;
|
|
||||||
bool is_missing;
|
|
||||||
while (!n.IsLeaf()) {
|
|
||||||
float fvalue = GetFvalue(ridx, n.GetFidx(), data, row_ptr, &is_missing);
|
|
||||||
// Missing value
|
|
||||||
if (is_missing) {
|
|
||||||
n = tree[n.MissingIdx()];
|
|
||||||
} else {
|
|
||||||
if (fvalue < n.GetFvalue()) {
|
|
||||||
node_id = n.left_child_idx;
|
|
||||||
n = tree[n.left_child_idx];
|
|
||||||
} else {
|
} else {
|
||||||
node_id = n.right_child_idx;
|
node = nodes + node->RightChildIdx();
|
||||||
n = tree[n.right_child_idx];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return n.GetWeight();
|
return node->GetWeight();
|
||||||
}
|
}
|
||||||
|
|
||||||
void DevicePredictInternal(::sycl::queue qu,
|
float GetLeafWeight(const Node* nodes, const float* fval_buff) {
|
||||||
sycl::DeviceMatrix* dmat,
|
const Node* node = nodes;
|
||||||
|
while (!node->IsLeaf()) {
|
||||||
|
const float fvalue = fval_buff[node->GetFidx()];
|
||||||
|
if (fvalue < node->GetFvalue()) {
|
||||||
|
node = nodes + node->LeftChildIdx();
|
||||||
|
} else {
|
||||||
|
node = nodes + node->RightChildIdx();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return node->GetWeight();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool any_missing>
|
||||||
|
void DevicePredictInternal(::sycl::queue* qu,
|
||||||
|
const sycl::DeviceMatrix& dmat,
|
||||||
HostDeviceVector<float>* out_preds,
|
HostDeviceVector<float>* out_preds,
|
||||||
const gbm::GBTreeModel& model,
|
const gbm::GBTreeModel& model,
|
||||||
size_t tree_begin,
|
size_t tree_begin,
|
||||||
@ -194,43 +170,75 @@ void DevicePredictInternal(::sycl::queue qu,
|
|||||||
DeviceModel device_model;
|
DeviceModel device_model;
|
||||||
device_model.Init(qu, model, tree_begin, tree_end);
|
device_model.Init(qu, model, tree_begin, tree_end);
|
||||||
|
|
||||||
auto& out_preds_vec = out_preds->HostVector();
|
const Node* nodes = device_model.nodes.DataConst();
|
||||||
|
const size_t* first_node_position = device_model.first_node_position.DataConst();
|
||||||
DeviceNode* nodes = device_model.nodes_.Data();
|
const int* tree_group = device_model.tree_group.DataConst();
|
||||||
::sycl::buffer<float, 1> out_preds_buf(out_preds_vec.data(), out_preds_vec.size());
|
const size_t* row_ptr = dmat.row_ptr.DataConst();
|
||||||
size_t* tree_segments = device_model.tree_segments_.Data();
|
const Entry* data = dmat.data.DataConst();
|
||||||
int* tree_group = device_model.tree_group_.Data();
|
int num_features = dmat.p_mat->Info().num_col_;
|
||||||
size_t* row_ptr = dmat->row_ptr.Data();
|
int num_rows = dmat.row_ptr.Size() - 1;
|
||||||
Entry* data = dmat->data.Data();
|
|
||||||
int num_features = dmat->p_mat->Info().num_col_;
|
|
||||||
int num_rows = dmat->row_ptr.Size() - 1;
|
|
||||||
int num_group = model.learner_model_param->num_output_group;
|
int num_group = model.learner_model_param->num_output_group;
|
||||||
|
|
||||||
qu.submit([&](::sycl::handler& cgh) {
|
USMVector<float, MemoryType::on_device> fval_buff(qu, num_features * num_rows);
|
||||||
|
USMVector<uint8_t, MemoryType::on_device> miss_buff;
|
||||||
|
auto* fval_buff_ptr = fval_buff.Data();
|
||||||
|
|
||||||
|
std::vector<::sycl::event> events(1);
|
||||||
|
if constexpr (any_missing) {
|
||||||
|
miss_buff.Resize(qu, num_features * num_rows, 1, &events[0]);
|
||||||
|
}
|
||||||
|
auto* miss_buff_ptr = miss_buff.Data();
|
||||||
|
|
||||||
|
auto& out_preds_vec = out_preds->HostVector();
|
||||||
|
::sycl::buffer<float, 1> out_preds_buf(out_preds_vec.data(), out_preds_vec.size());
|
||||||
|
events[0] = qu->submit([&](::sycl::handler& cgh) {
|
||||||
|
cgh.depends_on(events[0]);
|
||||||
auto out_predictions = out_preds_buf.template get_access<::sycl::access::mode::read_write>(cgh);
|
auto out_predictions = out_preds_buf.template get_access<::sycl::access::mode::read_write>(cgh);
|
||||||
cgh.parallel_for<>(::sycl::range<1>(num_rows), [=](::sycl::id<1> pid) {
|
cgh.parallel_for<>(::sycl::range<1>(num_rows), [=](::sycl::id<1> pid) {
|
||||||
int global_idx = pid[0];
|
int row_idx = pid[0];
|
||||||
if (global_idx >= num_rows) return;
|
auto* fval_buff_row_ptr = fval_buff_ptr + num_features * row_idx;
|
||||||
|
auto* miss_buff_row_ptr = miss_buff_ptr + num_features * row_idx;
|
||||||
|
|
||||||
|
const Entry* first_entry = data + row_ptr[row_idx];
|
||||||
|
const Entry* last_entry = data + row_ptr[row_idx + 1];
|
||||||
|
for (const Entry* entry = first_entry; entry < last_entry; entry += 1) {
|
||||||
|
fval_buff_row_ptr[entry->index] = entry->fvalue;
|
||||||
|
if constexpr (any_missing) {
|
||||||
|
miss_buff_row_ptr[entry->index] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (num_group == 1) {
|
if (num_group == 1) {
|
||||||
float sum = 0.0;
|
float sum = 0.0;
|
||||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||||
const DeviceNode* tree = nodes + tree_segments[tree_idx - tree_begin];
|
const Node* first_node = nodes + first_node_position[tree_idx - tree_begin];
|
||||||
sum += GetLeafWeight(global_idx, tree, data, row_ptr);
|
if constexpr (any_missing) {
|
||||||
|
sum += GetLeafWeight(first_node, fval_buff_row_ptr, miss_buff_row_ptr);
|
||||||
|
} else {
|
||||||
|
sum += GetLeafWeight(first_node, fval_buff_row_ptr);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
out_predictions[global_idx] += sum;
|
out_predictions[row_idx] += sum;
|
||||||
} else {
|
} else {
|
||||||
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
|
||||||
const DeviceNode* tree = nodes + tree_segments[tree_idx - tree_begin];
|
const Node* first_node = nodes + first_node_position[tree_idx - tree_begin];
|
||||||
int out_prediction_idx = global_idx * num_group + tree_group[tree_idx];
|
int out_prediction_idx = row_idx * num_group + tree_group[tree_idx];
|
||||||
out_predictions[out_prediction_idx] += GetLeafWeight(global_idx, tree, data, row_ptr);
|
if constexpr (any_missing) {
|
||||||
|
out_predictions[out_prediction_idx] +=
|
||||||
|
GetLeafWeight(first_node, fval_buff_row_ptr, miss_buff_row_ptr);
|
||||||
|
} else {
|
||||||
|
out_predictions[out_prediction_idx] +=
|
||||||
|
GetLeafWeight(first_node, fval_buff_row_ptr);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}).wait();
|
});
|
||||||
|
qu->wait();
|
||||||
}
|
}
|
||||||
|
|
||||||
class Predictor : public xgboost::Predictor {
|
class Predictor : public xgboost::Predictor {
|
||||||
protected:
|
public:
|
||||||
void InitOutPredictions(const MetaInfo& info,
|
void InitOutPredictions(const MetaInfo& info,
|
||||||
HostDeviceVector<bst_float>* out_preds,
|
HostDeviceVector<bst_float>* out_preds,
|
||||||
const gbm::GBTreeModel& model) const override {
|
const gbm::GBTreeModel& model) const override {
|
||||||
@ -263,7 +271,6 @@ class Predictor : public xgboost::Predictor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
|
||||||
explicit Predictor(Context const* context) :
|
explicit Predictor(Context const* context) :
|
||||||
xgboost::Predictor::Predictor{context},
|
xgboost::Predictor::Predictor{context},
|
||||||
cpu_predictor(xgboost::Predictor::Create("cpu_predictor", context)) {}
|
cpu_predictor(xgboost::Predictor::Create("cpu_predictor", context)) {}
|
||||||
@ -281,7 +288,12 @@ class Predictor : public xgboost::Predictor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (tree_begin < tree_end) {
|
if (tree_begin < tree_end) {
|
||||||
DevicePredictInternal(qu, &device_matrix, out_preds, model, tree_begin, tree_end);
|
const bool any_missing = !(dmat->IsDense());
|
||||||
|
if (any_missing) {
|
||||||
|
DevicePredictInternal<true>(&qu, device_matrix, out_preds, model, tree_begin, tree_end);
|
||||||
|
} else {
|
||||||
|
DevicePredictInternal<false>(&qu, device_matrix, out_preds, model, tree_begin, tree_end);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user