Some comments for row partitioner. (#4832)
This commit is contained in:
@@ -152,8 +152,8 @@ struct ELLPackMatrix {
|
||||
|
||||
XGBOOST_DEVICE size_t BinCount() const { return gidx_fvalue_map.size(); }
|
||||
|
||||
// Get a matrix element, uses binary search for look up
|
||||
// Return NaN if missing
|
||||
// Get a matrix element, uses binary search for look up Return NaN if missing
|
||||
// Given a row index and a feature index, returns the corresponding cut value
|
||||
__device__ bst_float GetElement(size_t ridx, size_t fidx) const {
|
||||
auto row_begin = row_stride * ridx;
|
||||
auto row_end = row_begin + row_stride;
|
||||
@@ -832,14 +832,15 @@ struct DeviceShard {
|
||||
row_partitioner->UpdatePosition(
|
||||
nidx, split_node.LeftChild(), split_node.RightChild(),
|
||||
[=] __device__(bst_uint ridx) {
|
||||
bst_float element =
|
||||
// given a row index, returns the node id it belongs to
|
||||
bst_float cut_value =
|
||||
d_matrix.GetElement(ridx, split_node.SplitIndex());
|
||||
// Missing value
|
||||
int new_position = 0;
|
||||
if (isnan(element)) {
|
||||
if (isnan(cut_value)) {
|
||||
new_position = split_node.DefaultChild();
|
||||
} else {
|
||||
if (element <= split_node.SplitCond()) {
|
||||
if (cut_value <= split_node.SplitCond()) {
|
||||
new_position = split_node.LeftChild();
|
||||
} else {
|
||||
new_position = split_node.RightChild();
|
||||
|
||||
Reference in New Issue
Block a user