From 8bdf15120a3b2f2f698f6f2444fb145fe9a71817 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 26 Jun 2019 15:20:44 +0800 Subject: [PATCH] Implement tree model dump with code generator. (#4602) * Implement tree model dump with a code generator. * Split up generators. * Implement graphviz generator. * Use pattern matching. * [Breaking] Return a Source in `to_graphviz` instead of Digraph in Python package. Co-Authored-By: Philip Hyunsu Cho --- include/xgboost/feature_map.h | 2 + python-package/xgboost/core.py | 2 +- python-package/xgboost/plotting.py | 165 +++---- src/c_api/c_api.cc | 49 +- src/common/timer.h | 6 +- src/tree/tree_model.cc | 705 +++++++++++++++++++++++------ src/tree/updater_gpu.cu | 2 +- tests/cpp/tree/test_tree_model.cc | 117 +++++ tests/python/test_plotting.py | 4 +- tests/python/test_shap.py | 1 - tests/python/test_with_sklearn.py | 13 +- 11 files changed, 802 insertions(+), 264 deletions(-) diff --git a/include/xgboost/feature_map.h b/include/xgboost/feature_map.h index 2ccc16530..56dd126a5 100644 --- a/include/xgboost/feature_map.h +++ b/include/xgboost/feature_map.h @@ -7,6 +7,8 @@ #ifndef XGBOOST_FEATURE_MAP_H_ #define XGBOOST_FEATURE_MAP_H_ +#include + #include #include #include diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index dbf34e051..2dc1cafad 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1419,7 +1419,7 @@ class Booster(object): with_stats : bool, optional Controls whether the split statistics are output. dump_format : string, optional - Format of model dump. Can be 'text' or 'json'. + Format of model dump. Can be 'text', 'json' or 'dot'. """ length = c_bst_ulong() sarr = ctypes.POINTER(ctypes.c_char_p)() diff --git a/python-package/xgboost/plotting.py b/python-package/xgboost/plotting.py index 094bf8a51..5ac8d177d 100644 --- a/python-package/xgboost/plotting.py +++ b/python-package/xgboost/plotting.py @@ -1,10 +1,7 @@ -# coding: utf-8 # pylint: disable=too-many-locals, too-many-arguments, invalid-name, # pylint: disable=too-many-branches +# coding: utf-8 """Plotting Library.""" -from __future__ import absolute_import - -import re from io import BytesIO import numpy as np from .core import Booster @@ -61,7 +58,8 @@ def plot_importance(booster, ax=None, height=0.2, raise ImportError('You must install matplotlib to plot importance') if isinstance(booster, XGBModel): - importance = booster.get_booster().get_score(importance_type=importance_type) + importance = booster.get_booster().get_score( + importance_type=importance_type) elif isinstance(booster, Booster): importance = booster.get_score(importance_type=importance_type) elif isinstance(booster, dict): @@ -117,56 +115,11 @@ def plot_importance(booster, ax=None, height=0.2, return ax -_NODEPAT = re.compile(r'(\d+):\[(.+)\]') -_LEAFPAT = re.compile(r'(\d+):(leaf=.+)') -_EDGEPAT = re.compile(r'yes=(\d+),no=(\d+),missing=(\d+)') -_EDGEPAT2 = re.compile(r'yes=(\d+),no=(\d+)') - - -def _parse_node(graph, text, condition_node_params, leaf_node_params): - """parse dumped node""" - match = _NODEPAT.match(text) - if match is not None: - node = match.group(1) - graph.node(node, label=match.group(2), **condition_node_params) - return node - match = _LEAFPAT.match(text) - if match is not None: - node = match.group(1) - graph.node(node, label=match.group(2), **leaf_node_params) - return node - raise ValueError('Unable to parse node: {0}'.format(text)) - - -def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'): - """parse dumped edge""" - try: - match = _EDGEPAT.match(text) - if match is not None: - yes, no, missing = match.groups() - if yes == missing: - graph.edge(node, yes, label='yes, missing', color=yes_color) - graph.edge(node, no, label='no', color=no_color) - else: - graph.edge(node, yes, label='yes', color=yes_color) - graph.edge(node, no, label='no, missing', color=no_color) - return - except ValueError: - pass - match = _EDGEPAT2.match(text) - if match is not None: - yes, no = match.groups() - graph.edge(node, yes, label='yes', color=yes_color) - graph.edge(node, no, label='no', color=no_color) - return - raise ValueError('Unable to parse edge: {0}'.format(text)) - - -def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT', - yes_color='#0000FF', no_color='#FF0000', +def to_graphviz(booster, fmap='', num_trees=0, rankdir=None, + yes_color=None, no_color=None, condition_node_params=None, leaf_node_params=None, **kwargs): - """Convert specified tree to graphviz instance. IPython can automatically plot the - returned graphiz instance. Otherwise, you should call .render() method + """Convert specified tree to graphviz instance. IPython can automatically plot + the returned graphiz instance. Otherwise, you should call .render() method of the returned graphiz instance. Parameters @@ -184,64 +137,77 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT', no_color : str, default '#FF0000' Edge color when doesn't meet the node condition. condition_node_params : dict (optional) - condition node configuration, - {'shape':'box', - 'style':'filled,rounded', - 'fillcolor':'#78bceb'} + Condition node configuration for for graphviz. Example: + + .. code-block:: python + + {'shape': 'box', + 'style': 'filled,rounded', + 'fillcolor': '#78bceb'} leaf_node_params : dict (optional) - leaf node configuration - {'shape':'box', - 'style':'filled', - 'fillcolor':'#e48038'} + Leaf node configuration for graphviz. Example: - kwargs : - Other keywords passed to graphviz graph_attr + .. code-block:: python + + {'shape': 'box', + 'style': 'filled', + 'fillcolor': '#e48038'} + + kwargs : Other keywords passed to graphviz graph_attr, E.g.: + ``graph [ {key} = {value} ]`` Returns ------- - ax : matplotlib Axes + graph: graphviz.Source + """ - - if condition_node_params is None: - condition_node_params = {} - if leaf_node_params is None: - leaf_node_params = {} - try: - from graphviz import Digraph + from graphviz import Source except ImportError: raise ImportError('You must install graphviz to plot tree') - - if not isinstance(booster, (Booster, XGBModel)): - raise ValueError('booster must be Booster or XGBModel instance') - if isinstance(booster, XGBModel): booster = booster.get_booster() - tree = booster.get_dump(fmap=fmap)[num_trees] - tree = tree.split() + # squash everything back into kwargs again for compatibility + parameters = 'dot' + extra = {} + for key, value in kwargs.items(): + extra[key] = value - kwargs = kwargs.copy() - kwargs.update({'rankdir': rankdir}) - graph = Digraph(graph_attr=kwargs) - - for i, text in enumerate(tree): - if text[0].isdigit(): - node = _parse_node( - graph, text, condition_node_params=condition_node_params, - leaf_node_params=leaf_node_params) + if rankdir is not None: + kwargs['graph_attrs'] = {} + kwargs['graph_attrs']['rankdir'] = rankdir + for key, value in extra.items(): + if 'graph_attrs' in kwargs.keys(): + kwargs['graph_attrs'][key] = value else: - if i == 0: - # 1st string must be node - raise ValueError('Unable to parse given string as tree') - _parse_edge(graph, node, text, yes_color=yes_color, - no_color=no_color) + kwargs['graph_attrs'] = {} + del kwargs[key] - return graph + if yes_color is not None or no_color is not None: + kwargs['edge'] = {} + if yes_color is not None: + kwargs['edge']['yes_color'] = yes_color + if no_color is not None: + kwargs['edge']['no_color'] = no_color + + if condition_node_params is not None: + kwargs['condition_node_params'] = condition_node_params + if leaf_node_params is not None: + kwargs['leaf_node_params'] = leaf_node_params + + if kwargs: + parameters += ':' + parameters += str(kwargs) + tree = booster.get_dump( + fmap=fmap, + dump_format=parameters)[num_trees] + g = Source(tree) + return g -def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs): +def plot_tree(booster, fmap='', num_trees=0, rankdir=None, ax=None, **kwargs): """Plot specified tree. Parameters @@ -252,7 +218,7 @@ def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs): The name of feature map file num_trees : int, default 0 Specify the ordinal number of target tree - rankdir : str, default "UT" + rankdir : str, default "TB" Passed to graphiz via graph_attr ax : matplotlib Axes, default None Target axes instance. If None, new figure and axes will be created. @@ -264,18 +230,17 @@ def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs): ax : matplotlib Axes """ - try: - import matplotlib.pyplot as plt - import matplotlib.image as image + from matplotlib import pyplot as plt + from matplotlib import image except ImportError: raise ImportError('You must install matplotlib to plot tree') if ax is None: _, ax = plt.subplots(1, 1) - g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, - rankdir=rankdir, **kwargs) + g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir, + **kwargs) s = BytesIO() s.write(g.pipe(format='png')) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 3d85b94c3..8e44b6e61 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -1033,19 +1033,21 @@ inline void XGBoostDumpModelImpl( *out_models = dmlc::BeginPtr(charp_vecs); *len = static_cast(charp_vecs.size()); } + XGB_DLL int XGBoosterDumpModel(BoosterHandle handle, - const char* fmap, - int with_stats, - xgboost::bst_ulong* len, - const char*** out_models) { + const char* fmap, + int with_stats, + xgboost::bst_ulong* len, + const char*** out_models) { return XGBoosterDumpModelEx(handle, fmap, with_stats, "text", len, out_models); } + XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle, - const char* fmap, - int with_stats, - const char *format, - xgboost::bst_ulong* len, - const char*** out_models) { + const char* fmap, + int with_stats, + const char *format, + xgboost::bst_ulong* len, + const char*** out_models) { API_BEGIN(); CHECK_HANDLE(); FeatureMap featmap; @@ -1060,23 +1062,24 @@ XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle, } XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle, - int fnum, - const char** fname, - const char** ftype, - int with_stats, - xgboost::bst_ulong* len, - const char*** out_models) { + int fnum, + const char** fname, + const char** ftype, + int with_stats, + xgboost::bst_ulong* len, + const char*** out_models) { return XGBoosterDumpModelExWithFeatures(handle, fnum, fname, ftype, with_stats, - "text", len, out_models); + "text", len, out_models); } + XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle, - int fnum, - const char** fname, - const char** ftype, - int with_stats, - const char *format, - xgboost::bst_ulong* len, - const char*** out_models) { + int fnum, + const char** fname, + const char** ftype, + int with_stats, + const char *format, + xgboost::bst_ulong* len, + const char*** out_models) { API_BEGIN(); CHECK_HANDLE(); FeatureMap featmap; diff --git a/src/common/timer.h b/src/common/timer.h index c9aa9de18..72db5d8fc 100644 --- a/src/common/timer.h +++ b/src/common/timer.h @@ -10,7 +10,7 @@ #if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__) #include -#endif +#endif // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__) namespace xgboost { namespace common { @@ -98,7 +98,7 @@ struct Monitor { stats.timer.Start(); #if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__) stats.nvtx_id = nvtxRangeStartA(name.c_str()); -#endif +#endif // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__) } } void StopCuda(const std::string &name) { @@ -108,7 +108,7 @@ struct Monitor { stats.count++; #if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__) nvtxRangeEnd(stats.nvtx_id); -#endif +#endif // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__) } } }; diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 88aa795d0..9d66eef6a 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -1,14 +1,19 @@ /*! - * Copyright 2015 by Contributors + * Copyright 2015-2019 by Contributors * \file tree_model.cc * \brief model structure for tree */ +#include +#include + #include +#include #include #include #include #include -#include "./param.h" + +#include "param.h" namespace xgboost { // register tree parameter @@ -17,158 +22,602 @@ DMLC_REGISTER_PARAMETER(TreeParam); namespace tree { DMLC_REGISTER_PARAMETER(TrainParam); } -// internal function to dump regression tree to text -void DumpRegTree(std::stringstream& fo, // NOLINT(*) - const RegTree& tree, - const FeatureMap& fmap, - int nid, int depth, int add_comma, - bool with_stats, std::string format) { - int float_max_precision = std::numeric_limits::max_digits10; - if (format == "json") { - if (add_comma) { - fo << ","; - } - if (depth != 0) { - fo << std::endl; - } - for (int i = 0; i < depth + 1; ++i) { - fo << " "; - } - } else { - for (int i = 0; i < depth; ++i) { - fo << '\t'; - } + +/*! + * \brief Base class for dump model implementation, modeling closely after code generator. + */ +class TreeGenerator { + protected: + static int32_t constexpr kFloatMaxPrecision = + std::numeric_limits::max_digits10; + FeatureMap const& fmap_; + std::stringstream ss_; + bool const with_stats_; + + template + static std::string ToStr(Float value) { + static_assert(std::is_floating_point::value, + "Use std::to_string instead for non-floating point values."); + std::stringstream ss; + ss << std::setprecision(kFloatMaxPrecision) << value; + return ss.str(); } - if (tree[nid].IsLeaf()) { - if (format == "json") { - fo << "{ \"nodeid\": " << nid - << ", \"leaf\": " << std::setprecision(float_max_precision) << tree[nid].LeafValue(); - if (with_stats) { - fo << ", \"cover\": " << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess; - } - fo << " }"; - } else { - fo << nid << ":leaf=" << std::setprecision(float_max_precision) << tree[nid].LeafValue(); - if (with_stats) { - fo << ",cover=" << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess; - } - fo << '\n'; + + static std::string Tabs(uint32_t n) { + std::string res; + for (uint32_t i = 0; i < n; ++i) { + res += '\t'; } - } else { - // right then left, - bst_float cond = tree[nid].SplitCond(); - const unsigned split_index = tree[nid].SplitIndex(); - if (split_index < fmap.Size()) { - switch (fmap.type(split_index)) { + return res; + } + /* \brief Find the first occurance of key in input and replace it with corresponding + * value. + */ + static std::string Match(std::string const& input, + std::map const& replacements) { + std::string result = input; + for (auto const& kv : replacements) { + auto pos = result.find(kv.first); + CHECK_NE(pos, std::string::npos); + result.replace(pos, kv.first.length(), kv.second); + } + return result; + } + + virtual std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) { + return ""; + } + virtual std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) { + return ""; + } + virtual std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) { + return ""; + } + virtual std::string NodeStat(RegTree const& tree, int32_t nid) { + return ""; + } + + virtual std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0; + + virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) { + auto const split_index = tree[nid].SplitIndex(); + std::string result; + if (split_index < fmap_.Size()) { + switch (fmap_.type(split_index)) { case FeatureMap::kIndicator: { - int nyes = tree[nid].DefaultLeft() ? - tree[nid].RightChild() : tree[nid].LeftChild(); - if (format == "json") { - fo << "{ \"nodeid\": " << nid - << ", \"depth\": " << depth - << ", \"split\": \"" << fmap.Name(split_index) << "\"" - << ", \"yes\": " << nyes - << ", \"no\": " << tree[nid].DefaultChild(); - } else { - fo << nid << ":[" << fmap.Name(split_index) << "] yes=" << nyes - << ",no=" << tree[nid].DefaultChild(); - } + result = this->Indicator(tree, nid, depth); break; } case FeatureMap::kInteger: { - const bst_float floored = std::floor(cond); - const int integer_threshold - = (floored == cond) ? static_cast(floored) - : static_cast(floored) + 1; - if (format == "json") { - fo << "{ \"nodeid\": " << nid - << ", \"depth\": " << depth - << ", \"split\": \"" << fmap.Name(split_index) << "\"" - << ", \"split_condition\": " << integer_threshold - << ", \"yes\": " << tree[nid].LeftChild() - << ", \"no\": " << tree[nid].RightChild() - << ", \"missing\": " << tree[nid].DefaultChild(); - } else { - fo << nid << ":[" << fmap.Name(split_index) << "<" - << integer_threshold - << "] yes=" << tree[nid].LeftChild() - << ",no=" << tree[nid].RightChild() - << ",missing=" << tree[nid].DefaultChild(); - } + result = this->Integer(tree, nid, depth); break; } case FeatureMap::kFloat: case FeatureMap::kQuantitive: { - if (format == "json") { - fo << "{ \"nodeid\": " << nid - << ", \"depth\": " << depth - << ", \"split\": \"" << fmap.Name(split_index) << "\"" - << ", \"split_condition\": " << std::setprecision(float_max_precision) << cond - << ", \"yes\": " << tree[nid].LeftChild() - << ", \"no\": " << tree[nid].RightChild() - << ", \"missing\": " << tree[nid].DefaultChild(); - } else { - fo << nid << ":[" << fmap.Name(split_index) - << "<" << std::setprecision(float_max_precision) << cond - << "] yes=" << tree[nid].LeftChild() - << ",no=" << tree[nid].RightChild() - << ",missing=" << tree[nid].DefaultChild(); - } + result = this->Quantitive(tree, nid, depth); break; } - default: LOG(FATAL) << "unknown fmap type"; - } + default: + LOG(FATAL) << "Unknown feature map type."; + } } else { - if (format == "json") { - fo << "{ \"nodeid\": " << nid - << ", \"depth\": " << depth - << ", \"split\": " << split_index - << ", \"split_condition\": " << std::setprecision(float_max_precision) << cond - << ", \"yes\": " << tree[nid].LeftChild() - << ", \"no\": " << tree[nid].RightChild() - << ", \"missing\": " << tree[nid].DefaultChild(); - } else { - fo << nid << ":[f" << split_index << "<"<< std::setprecision(float_max_precision) << cond - << "] yes=" << tree[nid].LeftChild() - << ",no=" << tree[nid].RightChild() - << ",missing=" << tree[nid].DefaultChild(); + result = this->PlainNode(tree, nid, depth); + } + return result; + } + + virtual std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0; + virtual std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) = 0; + + public: + TreeGenerator(FeatureMap const& _fmap, bool with_stats) : + fmap_{_fmap}, with_stats_{with_stats} {} + virtual ~TreeGenerator() = default; + + virtual void BuildTree(RegTree const& tree) { + ss_ << this->BuildTree(tree, 0, 0); + } + + std::string Str() const { + return ss_.str(); + } + + static TreeGenerator* Create(std::string const& attrs, FeatureMap const& fmap, + bool with_stats); +}; + +struct TreeGenReg : public dmlc::FunctionRegEntryBase< + TreeGenReg, + std::function > { +}; +} // namespace xgboost + + +namespace dmlc { +DMLC_REGISTRY_ENABLE(::xgboost::TreeGenReg); +} // namespace dmlc + +namespace xgboost { + +TreeGenerator* TreeGenerator::Create(std::string const& attrs, FeatureMap const& fmap, + bool with_stats) { + auto pos = attrs.find(':'); + std::string name; + std::string params; + if (pos != std::string::npos) { + name = attrs.substr(0, pos); + params = attrs.substr(pos+1, attrs.length() - pos - 1); + // Eliminate all occurances of single quote string. + size_t pos = std::string::npos; + while ((pos = params.find('\'')) != std::string::npos) { + params.replace(pos, 1, "\""); + } + } else { + name = attrs; + } + auto *e = ::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->Find(name); + if (e == nullptr) { + LOG(FATAL) << "Unknown Model Builder:" << name; + } + auto p_io_builder = (e->body)(fmap, params, with_stats); + return p_io_builder; +} + +#define XGBOOST_REGISTER_TREE_IO(UniqueId, Name) \ + static DMLC_ATTRIBUTE_UNUSED ::xgboost::TreeGenReg& \ + __make_ ## TreeGenReg ## _ ## UniqueId ## __ = \ + ::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->__REGISTER__(Name) + + +class TextGenerator : public TreeGenerator { + using SuperT = TreeGenerator; + + public: + TextGenerator(FeatureMap const& fmap, std::string const& attrs, bool with_stats) : + TreeGenerator(fmap, with_stats) {} + + std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override { + static std::string kLeafTemplate = "{tabs}{nid}:leaf={leaf}{stats}"; + static std::string kStatTemplate = ",cover={cover}"; + std::string result = SuperT::Match( + kLeafTemplate, + {{"{tabs}", SuperT::Tabs(depth)}, + {"{nid}", std::to_string(nid)}, + {"{leaf}", SuperT::ToStr(tree[nid].LeafValue())}, + {"{stats}", with_stats_ ? + SuperT::Match(kStatTemplate, + {{"{cover}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}}); + return result; + } + + std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override { + static std::string const kIndicatorTemplate = "{nid}:[{fname}] yes={yes},no={no}"; + int32_t nyes = tree[nid].DefaultLeft() ? + tree[nid].RightChild() : tree[nid].LeftChild(); + auto split_index = tree[nid].SplitIndex(); + std::string result = SuperT::Match( + kIndicatorTemplate, + {{"{nid}", std::to_string(nid)}, + {"{fname}", fmap_.Name(split_index)}, + {"{yes}", std::to_string(nyes)}, + {"{no}", std::to_string(tree[nid].DefaultChild())}}); + return result; + } + + std::string SplitNodeImpl( + RegTree const& tree, int32_t nid, std::string const& template_str, + std::string cond, uint32_t depth) { + auto split_index = tree[nid].SplitIndex(); + std::string const result = SuperT::Match( + template_str, + {{"{tabs}", SuperT::Tabs(depth)}, + {"{nid}", std::to_string(nid)}, + {"{fname}", split_index < fmap_.Size() ? fmap_.Name(split_index) : + std::to_string(split_index)}, + {"{cond}", cond}, + {"{left}", std::to_string(tree[nid].LeftChild())}, + {"{right}", std::to_string(tree[nid].RightChild())}, + {"{missing}", std::to_string(tree[nid].DefaultChild())}}); + return result; + } + + std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override { + static std::string const kIntegerTemplate = + "{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}"; + auto cond = tree[nid].SplitCond(); + const bst_float floored = std::floor(cond); + const int32_t integer_threshold + = (floored == cond) ? static_cast(floored) + : static_cast(floored) + 1; + return SplitNodeImpl(tree, nid, kIntegerTemplate, + std::to_string(integer_threshold), depth); + } + + std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override { + static std::string const kQuantitiveTemplate = + "{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}"; + auto cond = tree[nid].SplitCond(); + return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth); + } + + std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override { + auto cond = tree[nid].SplitCond(); + static std::string const kNodeTemplate = + "{tabs}{nid}:[f{fname}<{cond}] yes={left},no={right},missing={missing}"; + return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth); + } + + std::string NodeStat(RegTree const& tree, int32_t nid) override { + static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}"; + std::string const result = SuperT::Match( + kStatTemplate, + {{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)}, + {"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}); + return result; + } + + std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override { + if (tree[nid].IsLeaf()) { + return this->LeafNode(tree, nid, depth); + } + static std::string const kNodeTemplate = "{parent}{stat}\n{left}\n{right}"; + auto result = SuperT::Match( + kNodeTemplate, + {{"{parent}", this->SplitNode(tree, nid, depth)}, + {"{stat}", with_stats_ ? this->NodeStat(tree, nid) : ""}, + {"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)}, + {"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}}); + return result; + } + + void BuildTree(RegTree const& tree) override { + static std::string const& kTreeTemplate = "{nodes}\n"; + auto result = SuperT::Match( + kTreeTemplate, + {{"{nodes}", this->BuildTree(tree, 0, 0)}}); + ss_ << result; + } +}; + +XGBOOST_REGISTER_TREE_IO(TextGenerator, "text") +.describe("Dump text representation of tree") +.set_body([](FeatureMap const& fmap, std::string const& attrs, bool with_stats) { + return new TextGenerator(fmap, attrs, with_stats); + }); + +class JsonGenerator : public TreeGenerator { + using SuperT = TreeGenerator; + + public: + JsonGenerator(FeatureMap const& fmap, std::string attrs, bool with_stats) : + TreeGenerator(fmap, with_stats) {} + + std::string Indent(uint32_t depth) { + std::string result; + for (uint32_t i = 0; i < depth + 1; ++i) { + result += " "; + } + return result; + } + + std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override { + static std::string const kLeafTemplate = + R"L({ "nodeid": {nid}, "leaf": {leaf} {stat}})L"; + static std::string const kStatTemplate = + R"S(, "cover": {sum_hess} )S"; + std::string result = SuperT::Match( + kLeafTemplate, + {{"{nid}", std::to_string(nid)}, + {"{leaf}", SuperT::ToStr(tree[nid].LeafValue())}, + {"{stat}", with_stats_ ? SuperT::Match( + kStatTemplate, + {{"{sum_hess}", + SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}}); + return result; + } + + std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override { + int32_t nyes = tree[nid].DefaultLeft() ? + tree[nid].RightChild() : tree[nid].LeftChild(); + static std::string const kIndicatorTemplate = + R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no}})ID"; + auto split_index = tree[nid].SplitIndex(); + auto result = SuperT::Match( + kIndicatorTemplate, + {{"{nid}", std::to_string(nid)}, + {"{depth}", std::to_string(depth)}, + {"{fname}", fmap_.Name(split_index)}, + {"{yes}", std::to_string(nyes)}, + {"{no}", std::to_string(tree[nid].DefaultChild())}}); + return result; + } + + std::string SplitNodeImpl(RegTree const& tree, int32_t nid, + std::string const& template_str, std::string cond, uint32_t depth) { + auto split_index = tree[nid].SplitIndex(); + std::string const result = SuperT::Match( + template_str, + {{"{nid}", std::to_string(nid)}, + {"{depth}", std::to_string(depth)}, + {"{fname}", split_index < fmap_.Size() ? fmap_.Name(split_index) : + std::to_string(split_index)}, + {"{cond}", cond}, + {"{left}", std::to_string(tree[nid].LeftChild())}, + {"{right}", std::to_string(tree[nid].RightChild())}, + {"{missing}", std::to_string(tree[nid].DefaultChild())}}); + return result; + } + + std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override { + auto cond = tree[nid].SplitCond(); + const bst_float floored = std::floor(cond); + const int32_t integer_threshold + = (floored == cond) ? static_cast(floored) + : static_cast(floored) + 1; + static std::string const kIntegerTemplate = + R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I" + R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I" + R"I("missing": {missing})I"; + return SplitNodeImpl(tree, nid, kIntegerTemplate, + std::to_string(integer_threshold), depth); + } + + std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override { + static std::string const kQuantitiveTemplate = + R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I" + R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I" + R"I("missing": {missing})I"; + bst_float cond = tree[nid].SplitCond(); + return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth); + } + + std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override { + auto cond = tree[nid].SplitCond(); + static std::string const kNodeTemplate = + R"I( "nodeid": {nid}, "depth": {depth}, "split": {fname}, )I" + R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I" + R"I("missing": {missing})I"; + return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth); + } + + std::string NodeStat(RegTree const& tree, int32_t nid) override { + static std::string kStatTemplate = + R"S(, "gain": {loss_chg}, "cover": {sum_hess})S"; + auto result = SuperT::Match( + kStatTemplate, + {{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)}, + {"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}); + return result; + } + + std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) override { + std::string properties = SuperT::SplitNode(tree, nid, depth); + static std::string const kSplitNodeTemplate = + "{{properties} {stat}, \"children\": [{left}, {right}\n{indent}]}"; + auto result = SuperT::Match( + kSplitNodeTemplate, + {{"{properties}", properties}, + {"{stat}", with_stats_ ? this->NodeStat(tree, nid) : ""}, + {"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)}, + {"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}, + {"{indent}", this->Indent(depth)}}); + return result; + } + + std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override { + static std::string const kNodeTemplate = "{newline}{indent}{nodes}"; + auto result = SuperT::Match( + kNodeTemplate, + {{"{newline}", depth == 0 ? "" : "\n"}, + {"{indent}", Indent(depth)}, + {"{nodes}", tree[nid].IsLeaf() ? this->LeafNode(tree, nid, depth) : + this->SplitNode(tree, nid, depth)}}); + return result; + } +}; + +XGBOOST_REGISTER_TREE_IO(JsonGenerator, "json") +.describe("Dump json representation of tree") +.set_body([](FeatureMap const& fmap, std::string const& attrs, bool with_stats) { + return new JsonGenerator(fmap, attrs, with_stats); + }); + +struct GraphvizParam : public dmlc::Parameter { + std::string yes_color; + std::string no_color; + std::string rankdir; + std::string condition_node_params; + std::string leaf_node_params; + std::string graph_attrs; + + DMLC_DECLARE_PARAMETER(GraphvizParam){ + DMLC_DECLARE_FIELD(yes_color) + .set_default("#0000FF") + .describe("Edge color when meets the node condition."); + DMLC_DECLARE_FIELD(no_color) + .set_default("#FF0000") + .describe("Edge color when doesn't meet the node condition."); + DMLC_DECLARE_FIELD(rankdir) + .set_default("TB") + .describe("Passed to graphiz via graph_attr."); + DMLC_DECLARE_FIELD(condition_node_params) + .set_default("") + .describe("Conditional node configuration"); + DMLC_DECLARE_FIELD(leaf_node_params) + .set_default("") + .describe("Leaf node configuration"); + DMLC_DECLARE_FIELD(graph_attrs) + .set_default("") + .describe("Any other extra attributes for graphviz `graph_attr`."); + } +}; + +DMLC_REGISTER_PARAMETER(GraphvizParam); + +class GraphvizGenerator : public TreeGenerator { + using SuperT = TreeGenerator; + std::stringstream& ss_; + GraphvizParam param_; + + public: + GraphvizGenerator(FeatureMap const& fmap, std::string const& attrs, bool with_stats) : + TreeGenerator(fmap, with_stats), ss_{SuperT::ss_} { + param_.InitAllowUnknown(std::map{}); + using KwArg = std::map>; + KwArg kwargs; + if (attrs.length() != 0) { + std::istringstream iss(attrs); + try { + dmlc::JSONReader reader(&iss); + reader.Read(&kwargs); + } catch(dmlc::Error const& e) { + LOG(FATAL) << "Failed to parse graphviz parameters:\n\t" + << attrs << "\n" + << "With error:\n" + << e.what(); } } - if (with_stats) { - if (format == "json") { - fo << ", \"gain\": " << std::setprecision(float_max_precision) << tree.Stat(nid).loss_chg - << ", \"cover\": " << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess; - } else { - fo << ",gain=" << std::setprecision(float_max_precision) << tree.Stat(nid).loss_chg - << ",cover=" << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess; + // This turns out to be tricky, as `dmlc::Parameter::Load(JSONReader*)` doesn't + // support loading nested json objects. + if (kwargs.find("condition_node_params") != kwargs.cend()) { + auto const& cnp = kwargs["condition_node_params"]; + for (auto const& kv : cnp) { + param_.condition_node_params += kv.first + '=' + "\"" + kv.second + "\" "; } + kwargs.erase("condition_node_params"); } - if (format == "json") { - fo << ", \"children\": ["; - } else { - fo << '\n'; - } - DumpRegTree(fo, tree, fmap, tree[nid].LeftChild(), depth + 1, false, with_stats, format); - DumpRegTree(fo, tree, fmap, tree[nid].RightChild(), depth + 1, true, with_stats, format); - if (format == "json") { - fo << std::endl; - for (int i = 0; i < depth + 1; ++i) { - fo << " "; + if (kwargs.find("leaf_node_params") != kwargs.cend()) { + auto const& lnp = kwargs["leaf_node_params"]; + for (auto const& kv : lnp) { + param_.leaf_node_params += kv.first + '=' + "\"" + kv.second + "\" "; } - fo << "]}"; + kwargs.erase("leaf_node_params"); + } + + if (kwargs.find("edge") != kwargs.cend()) { + if (kwargs["edge"].find("yes_color") != kwargs["edge"].cend()) { + param_.yes_color = kwargs["edge"]["yes_color"]; + } + if (kwargs["edge"].find("no_color") != kwargs["edge"].cend()) { + param_.no_color = kwargs["edge"]["no_color"]; + } + kwargs.erase("edge"); + } + auto const& extra = kwargs["graph_attrs"]; + static std::string const kGraphTemplate = " graph [ {key}=\"{value}\" ]\n"; + for (auto const& kv : extra) { + param_.graph_attrs += SuperT::Match(kGraphTemplate, + {{"{key}", kv.first}, + {"{value}", kv.second}}); + } + + kwargs.erase("graph_attrs"); + if (kwargs.size() != 0) { + std::stringstream ss; + ss << "The following parameters for graphviz are not recognized:\n"; + for (auto kv : kwargs) { + ss << kv.first << ", "; + } + LOG(WARNING) << ss.str(); } } -} + + protected: + // Only indicator is different, so we combine all different node types into this + // function. + std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override { + auto split = tree[nid].SplitIndex(); + auto cond = tree[nid].SplitCond(); + static std::string const kNodeTemplate = + " {nid} [ label=\"{fname}{<}{cond}\" {params}]\n"; + + // Indicator only has fname. + bool has_less = (split >= fmap_.Size()) || fmap_.type(split) != FeatureMap::kIndicator; + std::string result = SuperT::Match(kNodeTemplate, { + {"{nid}", std::to_string(nid)}, + {"{fname}", split < fmap_.Size() ? fmap_.Name(split) : + 'f' + std::to_string(split)}, + {"{<}", has_less ? "<" : ""}, + {"{cond}", has_less ? SuperT::ToStr(cond) : ""}, + {"{params}", param_.condition_node_params}}); + + static std::string const kEdgeTemplate = + " {nid} -> {child} [label=\"{is_missing}\" color=\"{color}\"]\n"; + auto MatchFn = SuperT::Match; // mingw failed to capture protected fn. + auto BuildEdge = + [&tree, nid, MatchFn, this](int32_t child) { + bool is_missing = tree[nid].DefaultChild() == child; + std::string buffer = MatchFn(kEdgeTemplate, { + {"{nid}", std::to_string(nid)}, + {"{child}", std::to_string(child)}, + {"{color}", is_missing ? param_.yes_color : param_.no_color}, + {"{is_missing}", is_missing ? "yes, missing": "no"}}); + return buffer; + }; + result += BuildEdge(tree[nid].LeftChild()); + result += BuildEdge(tree[nid].RightChild()); + return result; + }; + + std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override { + static std::string const kLeafTemplate = + " {nid} [ label=\"leaf={leaf-value}\" {params}]\n"; + auto result = SuperT::Match(kLeafTemplate, { + {"{nid}", std::to_string(nid)}, + {"{leaf-value}", ToStr(tree[nid].LeafValue())}, + {"{params}", param_.leaf_node_params}}); + return result; + }; + + std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override { + if (tree[nid].IsLeaf()) { + return this->LeafNode(tree, nid, depth); + } + static std::string const kNodeTemplate = "{parent}\n{left}\n{right}"; + auto result = SuperT::Match( + kNodeTemplate, + {{"{parent}", this->PlainNode(tree, nid, depth)}, + {"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)}, + {"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}}); + return result; + } + + void BuildTree(RegTree const& tree) override { + static std::string const kTreeTemplate = + "digraph {\n" + " graph [ rankdir={rankdir} ]\n" + "{graph_attrs}\n" + "{nodes}}"; + auto result = SuperT::Match( + kTreeTemplate, + {{"{rankdir}", param_.rankdir}, + {"{graph_attrs}", param_.graph_attrs}, + {"{nodes}", this->BuildTree(tree, 0, 0)}}); + ss_ << result; + }; +}; + +XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot") +.describe("Dump graphviz representation of tree") +.set_body([](FeatureMap const& fmap, std::string attrs, bool with_stats) { + return new GraphvizGenerator(fmap, attrs, with_stats); + }); std::string RegTree::DumpModel(const FeatureMap& fmap, bool with_stats, std::string format) const { - std::stringstream fo(""); - for (int i = 0; i < param.num_roots; ++i) { - DumpRegTree(fo, *this, fmap, i, 0, false, with_stats, format); + std::unique_ptr builder { + TreeGenerator::Create(format, fmap, with_stats) + }; + for (int32_t i = 0; i < param.num_roots; ++i) { + builder->BuildTree(*this); } - return fo.str(); + + std::string result = builder->Str(); + return result; } + void RegTree::FillNodeMeanValues() { size_t num_nodes = this->param.num_nodes; if (this->node_mean_values_.size() == num_nodes) { diff --git a/src/tree/updater_gpu.cu b/src/tree/updater_gpu.cu index 24d775f30..bbb72f8aa 100644 --- a/src/tree/updater_gpu.cu +++ b/src/tree/updater_gpu.cu @@ -144,7 +144,7 @@ __global__ void CubScanByKeyL1( int previousKey = __shfl_up_sync(0xFFFFFFFF, myKey, 1); #else int previousKey = __shfl_up(myKey, 1); -#endif +#endif // (__CUDACC_VER_MAJOR__ >= 9) // Collectively compute the block-wide exclusive prefix sum BlockScan(temp_storage) .ExclusiveScan(threadData, threadData, rootPair, AddByKey()); diff --git a/tests/cpp/tree/test_tree_model.cc b/tests/cpp/tree/test_tree_model.cc index da0e6bcb6..bb7d69966 100644 --- a/tests/cpp/tree/test_tree_model.cc +++ b/tests/cpp/tree/test_tree_model.cc @@ -101,4 +101,121 @@ TEST(Tree, AllocateNode) { ASSERT_TRUE(nodes.at(1).IsLeaf()); ASSERT_TRUE(nodes.at(2).IsLeaf()); } + +RegTree ConstructTree() { + RegTree tree; + tree.ExpandNode( + /*nid=*/0, /*split_index=*/0, /*split_value=*/0.0f, + /*default_left=*/true, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); + auto left = tree[0].LeftChild(); + auto right = tree[0].RightChild(); + tree.ExpandNode( + /*nid=*/left, /*split_index=*/1, /*split_value=*/1.0f, + /*default_left=*/false, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); + tree.ExpandNode( + /*nid=*/right, /*split_index=*/2, /*split_value=*/2.0f, + /*default_left=*/false, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); + return tree; +} + +TEST(Tree, DumpJson) { + auto tree = ConstructTree(); + FeatureMap fmap; + auto str = tree.DumpModel(fmap, true, "json"); + size_t n_leaves = 0; + size_t iter = 0; + while ((iter = str.find("leaf", iter + 1)) != std::string::npos) { + n_leaves++; + } + ASSERT_EQ(n_leaves, 4); + + size_t n_conditions = 0; + iter = 0; + while ((iter = str.find("split_condition", iter + 1)) != std::string::npos) { + n_conditions++; + } + ASSERT_EQ(n_conditions, 3); + + fmap.PushBack(0, "feat_0", "i"); + fmap.PushBack(1, "feat_1", "q"); + fmap.PushBack(2, "feat_2", "int"); + + str = tree.DumpModel(fmap, true, "json"); + ASSERT_NE(str.find(R"("split": "feat_0")"), std::string::npos); + ASSERT_NE(str.find(R"("split": "feat_1")"), std::string::npos); + ASSERT_NE(str.find(R"("split": "feat_2")"), std::string::npos); + + str = tree.DumpModel(fmap, false, "json"); + ASSERT_EQ(str.find("cover"), std::string::npos); +} + +TEST(Tree, DumpText) { + auto tree = ConstructTree(); + FeatureMap fmap; + auto str = tree.DumpModel(fmap, true, "text"); + size_t n_leaves = 0; + size_t iter = 0; + while ((iter = str.find("leaf", iter + 1)) != std::string::npos) { + n_leaves++; + } + ASSERT_EQ(n_leaves, 4); + + iter = 0; + size_t n_conditions = 0; + while ((iter = str.find("gain", iter + 1)) != std::string::npos) { + n_conditions++; + } + ASSERT_EQ(n_conditions, 3); + + ASSERT_NE(str.find("[f0<0]"), std::string::npos); + ASSERT_NE(str.find("[f1<1]"), std::string::npos); + ASSERT_NE(str.find("[f2<2]"), std::string::npos); + + fmap.PushBack(0, "feat_0", "i"); + fmap.PushBack(1, "feat_1", "q"); + fmap.PushBack(2, "feat_2", "int"); + + str = tree.DumpModel(fmap, true, "text"); + ASSERT_NE(str.find("[feat_0]"), std::string::npos); + ASSERT_NE(str.find("[feat_1<1]"), std::string::npos); + ASSERT_NE(str.find("[feat_2<2]"), std::string::npos); + + str = tree.DumpModel(fmap, false, "text"); + ASSERT_EQ(str.find("cover"), std::string::npos); +} + +TEST(Tree, DumpDot) { + auto tree = ConstructTree(); + FeatureMap fmap; + auto str = tree.DumpModel(fmap, true, "dot"); + + size_t n_leaves = 0; + size_t iter = 0; + while ((iter = str.find("leaf", iter + 1)) != std::string::npos) { + n_leaves++; + } + ASSERT_EQ(n_leaves, 4); + + size_t n_edges = 0; + iter = 0; + while ((iter = str.find("->", iter + 1)) != std::string::npos) { + n_edges++; + } + ASSERT_EQ(n_edges, 6); + + fmap.PushBack(0, "feat_0", "i"); + fmap.PushBack(1, "feat_1", "q"); + fmap.PushBack(2, "feat_2", "int"); + + str = tree.DumpModel(fmap, true, "dot"); + ASSERT_NE(str.find(R"("feat_0")"), std::string::npos); + ASSERT_NE(str.find(R"(feat_1<1)"), std::string::npos); + ASSERT_NE(str.find(R"(feat_2<2)"), std::string::npos); + + str = tree.DumpModel(fmap, true, R"(dot:{"graph_attrs": {"bgcolor": "#FFFF00"}})"); + ASSERT_NE(str.find(R"(graph [ bgcolor="#FFFF00" ])"), std::string::npos); +} } // namespace xgboost diff --git a/tests/python/test_plotting.py b/tests/python/test_plotting.py index 7d98280e4..a5a62f8b9 100644 --- a/tests/python/test_plotting.py +++ b/tests/python/test_plotting.py @@ -10,7 +10,7 @@ try: import matplotlib matplotlib.use('Agg') from matplotlib.axes import Axes - from graphviz import Digraph + from graphviz import Source except ImportError: pass @@ -57,7 +57,7 @@ class TestPlotting(unittest.TestCase): assert ax.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # blue g = xgb.to_graphviz(bst2, num_trees=0) - assert isinstance(g, Digraph) + assert isinstance(g, Source) ax = xgb.plot_tree(bst2, num_trees=0) assert isinstance(ax, Axes) diff --git a/tests/python/test_shap.py b/tests/python/test_shap.py index ffddb6b8f..26580ef32 100644 --- a/tests/python/test_shap.py +++ b/tests/python/test_shap.py @@ -87,7 +87,6 @@ class TestSHAP(unittest.TestCase): r_exp = r"([0-9]+):\[f([0-9]+)<([0-9\.e-]+)\] yes=([0-9]+),no=([0-9]+).*cover=([0-9e\.]+)" r_exp_leaf = r"([0-9]+):leaf=([0-9\.e-]+),cover=([0-9e\.]+)" for tree in model.get_dump(with_stats=True): - lines = list(tree.splitlines()) trees.append([None for i in range(len(lines))]) for line in lines: diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 0db6d1aaf..09a88ac2b 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -352,7 +352,7 @@ def test_sklearn_plotting(): matplotlib.use('Agg') from matplotlib.axes import Axes - from graphviz import Digraph + from graphviz import Source ax = xgb.plot_importance(classifier) assert isinstance(ax, Axes) @@ -362,7 +362,7 @@ def test_sklearn_plotting(): assert len(ax.patches) == 4 g = xgb.to_graphviz(classifier, num_trees=0) - assert isinstance(g, Digraph) + assert isinstance(g, Source) ax = xgb.plot_tree(classifier, num_trees=0) assert isinstance(ax, Axes) @@ -641,7 +641,8 @@ def test_XGBClassifier_resume(): X, Y = load_breast_cancer(return_X_y=True) - model1 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8) + model1 = xgb.XGBClassifier( + learning_rate=0.3, random_state=0, n_estimators=8) model1.fit(X, Y) pred1 = model1.predict(X) @@ -649,7 +650,8 @@ def test_XGBClassifier_resume(): # file name of stored xgb model model1.save_model(model1_path) - model2 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8) + model2 = xgb.XGBClassifier( + learning_rate=0.3, random_state=0, n_estimators=8) model2.fit(X, Y, xgb_model=model1_path) pred2 = model2.predict(X) @@ -660,7 +662,8 @@ def test_XGBClassifier_resume(): # file name of 'Booster' instance Xgb model model1.get_booster().save_model(model1_booster_path) - model2 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8) + model2 = xgb.XGBClassifier( + learning_rate=0.3, random_state=0, n_estimators=8) model2.fit(X, Y, xgb_model=model1_booster_path) pred2 = model2.predict(X)