diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 0464e747b..c03587c14 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -6,6 +6,7 @@ #include #include #include +#include #include #include "./param.h" @@ -75,17 +76,21 @@ void DumpRegTree(std::stringstream& fo, // NOLINT(*) break; } case FeatureMap::kInteger: { + const bst_float floored = std::floor(cond); + const int integer_threshold + = (floored == cond) ? static_cast(floored) + : static_cast(floored) + 1; if (format == "json") { fo << "{ \"nodeid\": " << nid << ", \"depth\": " << depth << ", \"split\": \"" << fmap.Name(split_index) << "\"" - << ", \"split_condition\": " << int(cond + 1.0) + << ", \"split_condition\": " << integer_threshold << ", \"yes\": " << tree[nid].LeftChild() << ", \"no\": " << tree[nid].RightChild() << ", \"missing\": " << tree[nid].DefaultChild(); } else { fo << nid << ":[" << fmap.Name(split_index) << "<" - << int(cond + 1.0) + << integer_threshold << "] yes=" << tree[nid].LeftChild() << ",no=" << tree[nid].RightChild() << ",missing=" << tree[nid].DefaultChild();