Implement tree model dump with code generator. (#4602)
* Implement tree model dump with a code generator. * Split up generators. * Implement graphviz generator. * Use pattern matching. * [Breaking] Return a Source in `to_graphviz` instead of Digraph in Python package. Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
parent
fe2de6f415
commit
8bdf15120a
@ -7,6 +7,8 @@
|
|||||||
#ifndef XGBOOST_FEATURE_MAP_H_
|
#ifndef XGBOOST_FEATURE_MAP_H_
|
||||||
#define XGBOOST_FEATURE_MAP_H_
|
#define XGBOOST_FEATURE_MAP_H_
|
||||||
|
|
||||||
|
#include <xgboost/logging.h>
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
|||||||
@ -1419,7 +1419,7 @@ class Booster(object):
|
|||||||
with_stats : bool, optional
|
with_stats : bool, optional
|
||||||
Controls whether the split statistics are output.
|
Controls whether the split statistics are output.
|
||||||
dump_format : string, optional
|
dump_format : string, optional
|
||||||
Format of model dump. Can be 'text' or 'json'.
|
Format of model dump. Can be 'text', 'json' or 'dot'.
|
||||||
"""
|
"""
|
||||||
length = c_bst_ulong()
|
length = c_bst_ulong()
|
||||||
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
sarr = ctypes.POINTER(ctypes.c_char_p)()
|
||||||
|
|||||||
@ -1,10 +1,7 @@
|
|||||||
# coding: utf-8
|
|
||||||
# pylint: disable=too-many-locals, too-many-arguments, invalid-name,
|
# pylint: disable=too-many-locals, too-many-arguments, invalid-name,
|
||||||
# pylint: disable=too-many-branches
|
# pylint: disable=too-many-branches
|
||||||
|
# coding: utf-8
|
||||||
"""Plotting Library."""
|
"""Plotting Library."""
|
||||||
from __future__ import absolute_import
|
|
||||||
|
|
||||||
import re
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .core import Booster
|
from .core import Booster
|
||||||
@ -61,7 +58,8 @@ def plot_importance(booster, ax=None, height=0.2,
|
|||||||
raise ImportError('You must install matplotlib to plot importance')
|
raise ImportError('You must install matplotlib to plot importance')
|
||||||
|
|
||||||
if isinstance(booster, XGBModel):
|
if isinstance(booster, XGBModel):
|
||||||
importance = booster.get_booster().get_score(importance_type=importance_type)
|
importance = booster.get_booster().get_score(
|
||||||
|
importance_type=importance_type)
|
||||||
elif isinstance(booster, Booster):
|
elif isinstance(booster, Booster):
|
||||||
importance = booster.get_score(importance_type=importance_type)
|
importance = booster.get_score(importance_type=importance_type)
|
||||||
elif isinstance(booster, dict):
|
elif isinstance(booster, dict):
|
||||||
@ -117,56 +115,11 @@ def plot_importance(booster, ax=None, height=0.2,
|
|||||||
return ax
|
return ax
|
||||||
|
|
||||||
|
|
||||||
_NODEPAT = re.compile(r'(\d+):\[(.+)\]')
|
def to_graphviz(booster, fmap='', num_trees=0, rankdir=None,
|
||||||
_LEAFPAT = re.compile(r'(\d+):(leaf=.+)')
|
yes_color=None, no_color=None,
|
||||||
_EDGEPAT = re.compile(r'yes=(\d+),no=(\d+),missing=(\d+)')
|
|
||||||
_EDGEPAT2 = re.compile(r'yes=(\d+),no=(\d+)')
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_node(graph, text, condition_node_params, leaf_node_params):
|
|
||||||
"""parse dumped node"""
|
|
||||||
match = _NODEPAT.match(text)
|
|
||||||
if match is not None:
|
|
||||||
node = match.group(1)
|
|
||||||
graph.node(node, label=match.group(2), **condition_node_params)
|
|
||||||
return node
|
|
||||||
match = _LEAFPAT.match(text)
|
|
||||||
if match is not None:
|
|
||||||
node = match.group(1)
|
|
||||||
graph.node(node, label=match.group(2), **leaf_node_params)
|
|
||||||
return node
|
|
||||||
raise ValueError('Unable to parse node: {0}'.format(text))
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_edge(graph, node, text, yes_color='#0000FF', no_color='#FF0000'):
|
|
||||||
"""parse dumped edge"""
|
|
||||||
try:
|
|
||||||
match = _EDGEPAT.match(text)
|
|
||||||
if match is not None:
|
|
||||||
yes, no, missing = match.groups()
|
|
||||||
if yes == missing:
|
|
||||||
graph.edge(node, yes, label='yes, missing', color=yes_color)
|
|
||||||
graph.edge(node, no, label='no', color=no_color)
|
|
||||||
else:
|
|
||||||
graph.edge(node, yes, label='yes', color=yes_color)
|
|
||||||
graph.edge(node, no, label='no, missing', color=no_color)
|
|
||||||
return
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
match = _EDGEPAT2.match(text)
|
|
||||||
if match is not None:
|
|
||||||
yes, no = match.groups()
|
|
||||||
graph.edge(node, yes, label='yes', color=yes_color)
|
|
||||||
graph.edge(node, no, label='no', color=no_color)
|
|
||||||
return
|
|
||||||
raise ValueError('Unable to parse edge: {0}'.format(text))
|
|
||||||
|
|
||||||
|
|
||||||
def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
|
|
||||||
yes_color='#0000FF', no_color='#FF0000',
|
|
||||||
condition_node_params=None, leaf_node_params=None, **kwargs):
|
condition_node_params=None, leaf_node_params=None, **kwargs):
|
||||||
"""Convert specified tree to graphviz instance. IPython can automatically plot the
|
"""Convert specified tree to graphviz instance. IPython can automatically plot
|
||||||
returned graphiz instance. Otherwise, you should call .render() method
|
the returned graphiz instance. Otherwise, you should call .render() method
|
||||||
of the returned graphiz instance.
|
of the returned graphiz instance.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -184,64 +137,77 @@ def to_graphviz(booster, fmap='', num_trees=0, rankdir='UT',
|
|||||||
no_color : str, default '#FF0000'
|
no_color : str, default '#FF0000'
|
||||||
Edge color when doesn't meet the node condition.
|
Edge color when doesn't meet the node condition.
|
||||||
condition_node_params : dict (optional)
|
condition_node_params : dict (optional)
|
||||||
condition node configuration,
|
Condition node configuration for for graphviz. Example:
|
||||||
{'shape':'box',
|
|
||||||
'style':'filled,rounded',
|
.. code-block:: python
|
||||||
'fillcolor':'#78bceb'}
|
|
||||||
|
{'shape': 'box',
|
||||||
|
'style': 'filled,rounded',
|
||||||
|
'fillcolor': '#78bceb'}
|
||||||
|
|
||||||
leaf_node_params : dict (optional)
|
leaf_node_params : dict (optional)
|
||||||
leaf node configuration
|
Leaf node configuration for graphviz. Example:
|
||||||
{'shape':'box',
|
|
||||||
'style':'filled',
|
|
||||||
'fillcolor':'#e48038'}
|
|
||||||
|
|
||||||
kwargs :
|
.. code-block:: python
|
||||||
Other keywords passed to graphviz graph_attr
|
|
||||||
|
{'shape': 'box',
|
||||||
|
'style': 'filled',
|
||||||
|
'fillcolor': '#e48038'}
|
||||||
|
|
||||||
|
kwargs : Other keywords passed to graphviz graph_attr, E.g.:
|
||||||
|
``graph [ {key} = {value} ]``
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
ax : matplotlib Axes
|
graph: graphviz.Source
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if condition_node_params is None:
|
|
||||||
condition_node_params = {}
|
|
||||||
if leaf_node_params is None:
|
|
||||||
leaf_node_params = {}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from graphviz import Digraph
|
from graphviz import Source
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError('You must install graphviz to plot tree')
|
raise ImportError('You must install graphviz to plot tree')
|
||||||
|
|
||||||
if not isinstance(booster, (Booster, XGBModel)):
|
|
||||||
raise ValueError('booster must be Booster or XGBModel instance')
|
|
||||||
|
|
||||||
if isinstance(booster, XGBModel):
|
if isinstance(booster, XGBModel):
|
||||||
booster = booster.get_booster()
|
booster = booster.get_booster()
|
||||||
|
|
||||||
tree = booster.get_dump(fmap=fmap)[num_trees]
|
# squash everything back into kwargs again for compatibility
|
||||||
tree = tree.split()
|
parameters = 'dot'
|
||||||
|
extra = {}
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
extra[key] = value
|
||||||
|
|
||||||
kwargs = kwargs.copy()
|
if rankdir is not None:
|
||||||
kwargs.update({'rankdir': rankdir})
|
kwargs['graph_attrs'] = {}
|
||||||
graph = Digraph(graph_attr=kwargs)
|
kwargs['graph_attrs']['rankdir'] = rankdir
|
||||||
|
for key, value in extra.items():
|
||||||
for i, text in enumerate(tree):
|
if 'graph_attrs' in kwargs.keys():
|
||||||
if text[0].isdigit():
|
kwargs['graph_attrs'][key] = value
|
||||||
node = _parse_node(
|
|
||||||
graph, text, condition_node_params=condition_node_params,
|
|
||||||
leaf_node_params=leaf_node_params)
|
|
||||||
else:
|
else:
|
||||||
if i == 0:
|
kwargs['graph_attrs'] = {}
|
||||||
# 1st string must be node
|
del kwargs[key]
|
||||||
raise ValueError('Unable to parse given string as tree')
|
|
||||||
_parse_edge(graph, node, text, yes_color=yes_color,
|
|
||||||
no_color=no_color)
|
|
||||||
|
|
||||||
return graph
|
if yes_color is not None or no_color is not None:
|
||||||
|
kwargs['edge'] = {}
|
||||||
|
if yes_color is not None:
|
||||||
|
kwargs['edge']['yes_color'] = yes_color
|
||||||
|
if no_color is not None:
|
||||||
|
kwargs['edge']['no_color'] = no_color
|
||||||
|
|
||||||
|
if condition_node_params is not None:
|
||||||
|
kwargs['condition_node_params'] = condition_node_params
|
||||||
|
if leaf_node_params is not None:
|
||||||
|
kwargs['leaf_node_params'] = leaf_node_params
|
||||||
|
|
||||||
|
if kwargs:
|
||||||
|
parameters += ':'
|
||||||
|
parameters += str(kwargs)
|
||||||
|
tree = booster.get_dump(
|
||||||
|
fmap=fmap,
|
||||||
|
dump_format=parameters)[num_trees]
|
||||||
|
g = Source(tree)
|
||||||
|
return g
|
||||||
|
|
||||||
|
|
||||||
def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
|
def plot_tree(booster, fmap='', num_trees=0, rankdir=None, ax=None, **kwargs):
|
||||||
"""Plot specified tree.
|
"""Plot specified tree.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -252,7 +218,7 @@ def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
|
|||||||
The name of feature map file
|
The name of feature map file
|
||||||
num_trees : int, default 0
|
num_trees : int, default 0
|
||||||
Specify the ordinal number of target tree
|
Specify the ordinal number of target tree
|
||||||
rankdir : str, default "UT"
|
rankdir : str, default "TB"
|
||||||
Passed to graphiz via graph_attr
|
Passed to graphiz via graph_attr
|
||||||
ax : matplotlib Axes, default None
|
ax : matplotlib Axes, default None
|
||||||
Target axes instance. If None, new figure and axes will be created.
|
Target axes instance. If None, new figure and axes will be created.
|
||||||
@ -264,18 +230,17 @@ def plot_tree(booster, fmap='', num_trees=0, rankdir='UT', ax=None, **kwargs):
|
|||||||
ax : matplotlib Axes
|
ax : matplotlib Axes
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import matplotlib.pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
import matplotlib.image as image
|
from matplotlib import image
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError('You must install matplotlib to plot tree')
|
raise ImportError('You must install matplotlib to plot tree')
|
||||||
|
|
||||||
if ax is None:
|
if ax is None:
|
||||||
_, ax = plt.subplots(1, 1)
|
_, ax = plt.subplots(1, 1)
|
||||||
|
|
||||||
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees,
|
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir,
|
||||||
rankdir=rankdir, **kwargs)
|
**kwargs)
|
||||||
|
|
||||||
s = BytesIO()
|
s = BytesIO()
|
||||||
s.write(g.pipe(format='png'))
|
s.write(g.pipe(format='png'))
|
||||||
|
|||||||
@ -1033,19 +1033,21 @@ inline void XGBoostDumpModelImpl(
|
|||||||
*out_models = dmlc::BeginPtr(charp_vecs);
|
*out_models = dmlc::BeginPtr(charp_vecs);
|
||||||
*len = static_cast<xgboost::bst_ulong>(charp_vecs.size());
|
*len = static_cast<xgboost::bst_ulong>(charp_vecs.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL int XGBoosterDumpModel(BoosterHandle handle,
|
XGB_DLL int XGBoosterDumpModel(BoosterHandle handle,
|
||||||
const char* fmap,
|
const char* fmap,
|
||||||
int with_stats,
|
int with_stats,
|
||||||
xgboost::bst_ulong* len,
|
xgboost::bst_ulong* len,
|
||||||
const char*** out_models) {
|
const char*** out_models) {
|
||||||
return XGBoosterDumpModelEx(handle, fmap, with_stats, "text", len, out_models);
|
return XGBoosterDumpModelEx(handle, fmap, with_stats, "text", len, out_models);
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle,
|
XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle,
|
||||||
const char* fmap,
|
const char* fmap,
|
||||||
int with_stats,
|
int with_stats,
|
||||||
const char *format,
|
const char *format,
|
||||||
xgboost::bst_ulong* len,
|
xgboost::bst_ulong* len,
|
||||||
const char*** out_models) {
|
const char*** out_models) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
FeatureMap featmap;
|
FeatureMap featmap;
|
||||||
@ -1060,23 +1062,24 @@ XGB_DLL int XGBoosterDumpModelEx(BoosterHandle handle,
|
|||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
|
XGB_DLL int XGBoosterDumpModelWithFeatures(BoosterHandle handle,
|
||||||
int fnum,
|
int fnum,
|
||||||
const char** fname,
|
const char** fname,
|
||||||
const char** ftype,
|
const char** ftype,
|
||||||
int with_stats,
|
int with_stats,
|
||||||
xgboost::bst_ulong* len,
|
xgboost::bst_ulong* len,
|
||||||
const char*** out_models) {
|
const char*** out_models) {
|
||||||
return XGBoosterDumpModelExWithFeatures(handle, fnum, fname, ftype, with_stats,
|
return XGBoosterDumpModelExWithFeatures(handle, fnum, fname, ftype, with_stats,
|
||||||
"text", len, out_models);
|
"text", len, out_models);
|
||||||
}
|
}
|
||||||
|
|
||||||
XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle,
|
XGB_DLL int XGBoosterDumpModelExWithFeatures(BoosterHandle handle,
|
||||||
int fnum,
|
int fnum,
|
||||||
const char** fname,
|
const char** fname,
|
||||||
const char** ftype,
|
const char** ftype,
|
||||||
int with_stats,
|
int with_stats,
|
||||||
const char *format,
|
const char *format,
|
||||||
xgboost::bst_ulong* len,
|
xgboost::bst_ulong* len,
|
||||||
const char*** out_models) {
|
const char*** out_models) {
|
||||||
API_BEGIN();
|
API_BEGIN();
|
||||||
CHECK_HANDLE();
|
CHECK_HANDLE();
|
||||||
FeatureMap featmap;
|
FeatureMap featmap;
|
||||||
|
|||||||
@ -10,7 +10,7 @@
|
|||||||
|
|
||||||
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
|
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
|
||||||
#include <nvToolsExt.h>
|
#include <nvToolsExt.h>
|
||||||
#endif
|
#endif // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
namespace common {
|
namespace common {
|
||||||
@ -98,7 +98,7 @@ struct Monitor {
|
|||||||
stats.timer.Start();
|
stats.timer.Start();
|
||||||
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
|
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
|
||||||
stats.nvtx_id = nvtxRangeStartA(name.c_str());
|
stats.nvtx_id = nvtxRangeStartA(name.c_str());
|
||||||
#endif
|
#endif // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
void StopCuda(const std::string &name) {
|
void StopCuda(const std::string &name) {
|
||||||
@ -108,7 +108,7 @@ struct Monitor {
|
|||||||
stats.count++;
|
stats.count++;
|
||||||
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
|
#if defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
|
||||||
nvtxRangeEnd(stats.nvtx_id);
|
nvtxRangeEnd(stats.nvtx_id);
|
||||||
#endif
|
#endif // defined(XGBOOST_USE_NVTX) && defined(__CUDACC__)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -1,14 +1,19 @@
|
|||||||
/*!
|
/*!
|
||||||
* Copyright 2015 by Contributors
|
* Copyright 2015-2019 by Contributors
|
||||||
* \file tree_model.cc
|
* \file tree_model.cc
|
||||||
* \brief model structure for tree
|
* \brief model structure for tree
|
||||||
*/
|
*/
|
||||||
|
#include <dmlc/registry.h>
|
||||||
|
#include <dmlc/json.h>
|
||||||
|
|
||||||
#include <xgboost/tree_model.h>
|
#include <xgboost/tree_model.h>
|
||||||
|
#include <xgboost/logging.h>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
#include "./param.h"
|
|
||||||
|
#include "param.h"
|
||||||
|
|
||||||
namespace xgboost {
|
namespace xgboost {
|
||||||
// register tree parameter
|
// register tree parameter
|
||||||
@ -17,158 +22,602 @@ DMLC_REGISTER_PARAMETER(TreeParam);
|
|||||||
namespace tree {
|
namespace tree {
|
||||||
DMLC_REGISTER_PARAMETER(TrainParam);
|
DMLC_REGISTER_PARAMETER(TrainParam);
|
||||||
}
|
}
|
||||||
// internal function to dump regression tree to text
|
|
||||||
void DumpRegTree(std::stringstream& fo, // NOLINT(*)
|
/*!
|
||||||
const RegTree& tree,
|
* \brief Base class for dump model implementation, modeling closely after code generator.
|
||||||
const FeatureMap& fmap,
|
*/
|
||||||
int nid, int depth, int add_comma,
|
class TreeGenerator {
|
||||||
bool with_stats, std::string format) {
|
protected:
|
||||||
int float_max_precision = std::numeric_limits<bst_float>::max_digits10;
|
static int32_t constexpr kFloatMaxPrecision =
|
||||||
if (format == "json") {
|
std::numeric_limits<bst_float>::max_digits10;
|
||||||
if (add_comma) {
|
FeatureMap const& fmap_;
|
||||||
fo << ",";
|
std::stringstream ss_;
|
||||||
}
|
bool const with_stats_;
|
||||||
if (depth != 0) {
|
|
||||||
fo << std::endl;
|
template <typename Float>
|
||||||
}
|
static std::string ToStr(Float value) {
|
||||||
for (int i = 0; i < depth + 1; ++i) {
|
static_assert(std::is_floating_point<Float>::value,
|
||||||
fo << " ";
|
"Use std::to_string instead for non-floating point values.");
|
||||||
}
|
std::stringstream ss;
|
||||||
} else {
|
ss << std::setprecision(kFloatMaxPrecision) << value;
|
||||||
for (int i = 0; i < depth; ++i) {
|
return ss.str();
|
||||||
fo << '\t';
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (tree[nid].IsLeaf()) {
|
|
||||||
if (format == "json") {
|
static std::string Tabs(uint32_t n) {
|
||||||
fo << "{ \"nodeid\": " << nid
|
std::string res;
|
||||||
<< ", \"leaf\": " << std::setprecision(float_max_precision) << tree[nid].LeafValue();
|
for (uint32_t i = 0; i < n; ++i) {
|
||||||
if (with_stats) {
|
res += '\t';
|
||||||
fo << ", \"cover\": " << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess;
|
|
||||||
}
|
|
||||||
fo << " }";
|
|
||||||
} else {
|
|
||||||
fo << nid << ":leaf=" << std::setprecision(float_max_precision) << tree[nid].LeafValue();
|
|
||||||
if (with_stats) {
|
|
||||||
fo << ",cover=" << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess;
|
|
||||||
}
|
|
||||||
fo << '\n';
|
|
||||||
}
|
}
|
||||||
} else {
|
return res;
|
||||||
// right then left,
|
}
|
||||||
bst_float cond = tree[nid].SplitCond();
|
/* \brief Find the first occurance of key in input and replace it with corresponding
|
||||||
const unsigned split_index = tree[nid].SplitIndex();
|
* value.
|
||||||
if (split_index < fmap.Size()) {
|
*/
|
||||||
switch (fmap.type(split_index)) {
|
static std::string Match(std::string const& input,
|
||||||
|
std::map<std::string, std::string> const& replacements) {
|
||||||
|
std::string result = input;
|
||||||
|
for (auto const& kv : replacements) {
|
||||||
|
auto pos = result.find(kv.first);
|
||||||
|
CHECK_NE(pos, std::string::npos);
|
||||||
|
result.replace(pos, kv.first.length(), kv.second);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
virtual std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
virtual std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
virtual std::string NodeStat(RegTree const& tree, int32_t nid) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
|
||||||
|
|
||||||
|
virtual std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) {
|
||||||
|
auto const split_index = tree[nid].SplitIndex();
|
||||||
|
std::string result;
|
||||||
|
if (split_index < fmap_.Size()) {
|
||||||
|
switch (fmap_.type(split_index)) {
|
||||||
case FeatureMap::kIndicator: {
|
case FeatureMap::kIndicator: {
|
||||||
int nyes = tree[nid].DefaultLeft() ?
|
result = this->Indicator(tree, nid, depth);
|
||||||
tree[nid].RightChild() : tree[nid].LeftChild();
|
|
||||||
if (format == "json") {
|
|
||||||
fo << "{ \"nodeid\": " << nid
|
|
||||||
<< ", \"depth\": " << depth
|
|
||||||
<< ", \"split\": \"" << fmap.Name(split_index) << "\""
|
|
||||||
<< ", \"yes\": " << nyes
|
|
||||||
<< ", \"no\": " << tree[nid].DefaultChild();
|
|
||||||
} else {
|
|
||||||
fo << nid << ":[" << fmap.Name(split_index) << "] yes=" << nyes
|
|
||||||
<< ",no=" << tree[nid].DefaultChild();
|
|
||||||
}
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case FeatureMap::kInteger: {
|
case FeatureMap::kInteger: {
|
||||||
const bst_float floored = std::floor(cond);
|
result = this->Integer(tree, nid, depth);
|
||||||
const int integer_threshold
|
|
||||||
= (floored == cond) ? static_cast<int>(floored)
|
|
||||||
: static_cast<int>(floored) + 1;
|
|
||||||
if (format == "json") {
|
|
||||||
fo << "{ \"nodeid\": " << nid
|
|
||||||
<< ", \"depth\": " << depth
|
|
||||||
<< ", \"split\": \"" << fmap.Name(split_index) << "\""
|
|
||||||
<< ", \"split_condition\": " << integer_threshold
|
|
||||||
<< ", \"yes\": " << tree[nid].LeftChild()
|
|
||||||
<< ", \"no\": " << tree[nid].RightChild()
|
|
||||||
<< ", \"missing\": " << tree[nid].DefaultChild();
|
|
||||||
} else {
|
|
||||||
fo << nid << ":[" << fmap.Name(split_index) << "<"
|
|
||||||
<< integer_threshold
|
|
||||||
<< "] yes=" << tree[nid].LeftChild()
|
|
||||||
<< ",no=" << tree[nid].RightChild()
|
|
||||||
<< ",missing=" << tree[nid].DefaultChild();
|
|
||||||
}
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case FeatureMap::kFloat:
|
case FeatureMap::kFloat:
|
||||||
case FeatureMap::kQuantitive: {
|
case FeatureMap::kQuantitive: {
|
||||||
if (format == "json") {
|
result = this->Quantitive(tree, nid, depth);
|
||||||
fo << "{ \"nodeid\": " << nid
|
|
||||||
<< ", \"depth\": " << depth
|
|
||||||
<< ", \"split\": \"" << fmap.Name(split_index) << "\""
|
|
||||||
<< ", \"split_condition\": " << std::setprecision(float_max_precision) << cond
|
|
||||||
<< ", \"yes\": " << tree[nid].LeftChild()
|
|
||||||
<< ", \"no\": " << tree[nid].RightChild()
|
|
||||||
<< ", \"missing\": " << tree[nid].DefaultChild();
|
|
||||||
} else {
|
|
||||||
fo << nid << ":[" << fmap.Name(split_index)
|
|
||||||
<< "<" << std::setprecision(float_max_precision) << cond
|
|
||||||
<< "] yes=" << tree[nid].LeftChild()
|
|
||||||
<< ",no=" << tree[nid].RightChild()
|
|
||||||
<< ",missing=" << tree[nid].DefaultChild();
|
|
||||||
}
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default: LOG(FATAL) << "unknown fmap type";
|
default:
|
||||||
}
|
LOG(FATAL) << "Unknown feature map type.";
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if (format == "json") {
|
result = this->PlainNode(tree, nid, depth);
|
||||||
fo << "{ \"nodeid\": " << nid
|
}
|
||||||
<< ", \"depth\": " << depth
|
return result;
|
||||||
<< ", \"split\": " << split_index
|
}
|
||||||
<< ", \"split_condition\": " << std::setprecision(float_max_precision) << cond
|
|
||||||
<< ", \"yes\": " << tree[nid].LeftChild()
|
virtual std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
|
||||||
<< ", \"no\": " << tree[nid].RightChild()
|
virtual std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) = 0;
|
||||||
<< ", \"missing\": " << tree[nid].DefaultChild();
|
|
||||||
} else {
|
public:
|
||||||
fo << nid << ":[f" << split_index << "<"<< std::setprecision(float_max_precision) << cond
|
TreeGenerator(FeatureMap const& _fmap, bool with_stats) :
|
||||||
<< "] yes=" << tree[nid].LeftChild()
|
fmap_{_fmap}, with_stats_{with_stats} {}
|
||||||
<< ",no=" << tree[nid].RightChild()
|
virtual ~TreeGenerator() = default;
|
||||||
<< ",missing=" << tree[nid].DefaultChild();
|
|
||||||
|
virtual void BuildTree(RegTree const& tree) {
|
||||||
|
ss_ << this->BuildTree(tree, 0, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Str() const {
|
||||||
|
return ss_.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
static TreeGenerator* Create(std::string const& attrs, FeatureMap const& fmap,
|
||||||
|
bool with_stats);
|
||||||
|
};
|
||||||
|
|
||||||
|
struct TreeGenReg : public dmlc::FunctionRegEntryBase<
|
||||||
|
TreeGenReg,
|
||||||
|
std::function<TreeGenerator* (
|
||||||
|
FeatureMap const& fmap, std::string attrs, bool with_stats)> > {
|
||||||
|
};
|
||||||
|
} // namespace xgboost
|
||||||
|
|
||||||
|
|
||||||
|
namespace dmlc {
|
||||||
|
DMLC_REGISTRY_ENABLE(::xgboost::TreeGenReg);
|
||||||
|
} // namespace dmlc
|
||||||
|
|
||||||
|
namespace xgboost {
|
||||||
|
|
||||||
|
TreeGenerator* TreeGenerator::Create(std::string const& attrs, FeatureMap const& fmap,
|
||||||
|
bool with_stats) {
|
||||||
|
auto pos = attrs.find(':');
|
||||||
|
std::string name;
|
||||||
|
std::string params;
|
||||||
|
if (pos != std::string::npos) {
|
||||||
|
name = attrs.substr(0, pos);
|
||||||
|
params = attrs.substr(pos+1, attrs.length() - pos - 1);
|
||||||
|
// Eliminate all occurances of single quote string.
|
||||||
|
size_t pos = std::string::npos;
|
||||||
|
while ((pos = params.find('\'')) != std::string::npos) {
|
||||||
|
params.replace(pos, 1, "\"");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
name = attrs;
|
||||||
|
}
|
||||||
|
auto *e = ::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->Find(name);
|
||||||
|
if (e == nullptr) {
|
||||||
|
LOG(FATAL) << "Unknown Model Builder:" << name;
|
||||||
|
}
|
||||||
|
auto p_io_builder = (e->body)(fmap, params, with_stats);
|
||||||
|
return p_io_builder;
|
||||||
|
}
|
||||||
|
|
||||||
|
#define XGBOOST_REGISTER_TREE_IO(UniqueId, Name) \
|
||||||
|
static DMLC_ATTRIBUTE_UNUSED ::xgboost::TreeGenReg& \
|
||||||
|
__make_ ## TreeGenReg ## _ ## UniqueId ## __ = \
|
||||||
|
::dmlc::Registry< ::xgboost::TreeGenReg>::Get()->__REGISTER__(Name)
|
||||||
|
|
||||||
|
|
||||||
|
class TextGenerator : public TreeGenerator {
|
||||||
|
using SuperT = TreeGenerator;
|
||||||
|
|
||||||
|
public:
|
||||||
|
TextGenerator(FeatureMap const& fmap, std::string const& attrs, bool with_stats) :
|
||||||
|
TreeGenerator(fmap, with_stats) {}
|
||||||
|
|
||||||
|
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||||
|
static std::string kLeafTemplate = "{tabs}{nid}:leaf={leaf}{stats}";
|
||||||
|
static std::string kStatTemplate = ",cover={cover}";
|
||||||
|
std::string result = SuperT::Match(
|
||||||
|
kLeafTemplate,
|
||||||
|
{{"{tabs}", SuperT::Tabs(depth)},
|
||||||
|
{"{nid}", std::to_string(nid)},
|
||||||
|
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())},
|
||||||
|
{"{stats}", with_stats_ ?
|
||||||
|
SuperT::Match(kStatTemplate,
|
||||||
|
{{"{cover}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||||
|
static std::string const kIndicatorTemplate = "{nid}:[{fname}] yes={yes},no={no}";
|
||||||
|
int32_t nyes = tree[nid].DefaultLeft() ?
|
||||||
|
tree[nid].RightChild() : tree[nid].LeftChild();
|
||||||
|
auto split_index = tree[nid].SplitIndex();
|
||||||
|
std::string result = SuperT::Match(
|
||||||
|
kIndicatorTemplate,
|
||||||
|
{{"{nid}", std::to_string(nid)},
|
||||||
|
{"{fname}", fmap_.Name(split_index)},
|
||||||
|
{"{yes}", std::to_string(nyes)},
|
||||||
|
{"{no}", std::to_string(tree[nid].DefaultChild())}});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SplitNodeImpl(
|
||||||
|
RegTree const& tree, int32_t nid, std::string const& template_str,
|
||||||
|
std::string cond, uint32_t depth) {
|
||||||
|
auto split_index = tree[nid].SplitIndex();
|
||||||
|
std::string const result = SuperT::Match(
|
||||||
|
template_str,
|
||||||
|
{{"{tabs}", SuperT::Tabs(depth)},
|
||||||
|
{"{nid}", std::to_string(nid)},
|
||||||
|
{"{fname}", split_index < fmap_.Size() ? fmap_.Name(split_index) :
|
||||||
|
std::to_string(split_index)},
|
||||||
|
{"{cond}", cond},
|
||||||
|
{"{left}", std::to_string(tree[nid].LeftChild())},
|
||||||
|
{"{right}", std::to_string(tree[nid].RightChild())},
|
||||||
|
{"{missing}", std::to_string(tree[nid].DefaultChild())}});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||||
|
static std::string const kIntegerTemplate =
|
||||||
|
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
|
||||||
|
auto cond = tree[nid].SplitCond();
|
||||||
|
const bst_float floored = std::floor(cond);
|
||||||
|
const int32_t integer_threshold
|
||||||
|
= (floored == cond) ? static_cast<int>(floored)
|
||||||
|
: static_cast<int>(floored) + 1;
|
||||||
|
return SplitNodeImpl(tree, nid, kIntegerTemplate,
|
||||||
|
std::to_string(integer_threshold), depth);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||||
|
static std::string const kQuantitiveTemplate =
|
||||||
|
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
|
||||||
|
auto cond = tree[nid].SplitCond();
|
||||||
|
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||||
|
auto cond = tree[nid].SplitCond();
|
||||||
|
static std::string const kNodeTemplate =
|
||||||
|
"{tabs}{nid}:[f{fname}<{cond}] yes={left},no={right},missing={missing}";
|
||||||
|
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string NodeStat(RegTree const& tree, int32_t nid) override {
|
||||||
|
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
|
||||||
|
std::string const result = SuperT::Match(
|
||||||
|
kStatTemplate,
|
||||||
|
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)},
|
||||||
|
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||||
|
if (tree[nid].IsLeaf()) {
|
||||||
|
return this->LeafNode(tree, nid, depth);
|
||||||
|
}
|
||||||
|
static std::string const kNodeTemplate = "{parent}{stat}\n{left}\n{right}";
|
||||||
|
auto result = SuperT::Match(
|
||||||
|
kNodeTemplate,
|
||||||
|
{{"{parent}", this->SplitNode(tree, nid, depth)},
|
||||||
|
{"{stat}", with_stats_ ? this->NodeStat(tree, nid) : ""},
|
||||||
|
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
|
||||||
|
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void BuildTree(RegTree const& tree) override {
|
||||||
|
static std::string const& kTreeTemplate = "{nodes}\n";
|
||||||
|
auto result = SuperT::Match(
|
||||||
|
kTreeTemplate,
|
||||||
|
{{"{nodes}", this->BuildTree(tree, 0, 0)}});
|
||||||
|
ss_ << result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
XGBOOST_REGISTER_TREE_IO(TextGenerator, "text")
|
||||||
|
.describe("Dump text representation of tree")
|
||||||
|
.set_body([](FeatureMap const& fmap, std::string const& attrs, bool with_stats) {
|
||||||
|
return new TextGenerator(fmap, attrs, with_stats);
|
||||||
|
});
|
||||||
|
|
||||||
|
class JsonGenerator : public TreeGenerator {
|
||||||
|
using SuperT = TreeGenerator;
|
||||||
|
|
||||||
|
public:
|
||||||
|
JsonGenerator(FeatureMap const& fmap, std::string attrs, bool with_stats) :
|
||||||
|
TreeGenerator(fmap, with_stats) {}
|
||||||
|
|
||||||
|
std::string Indent(uint32_t depth) {
|
||||||
|
std::string result;
|
||||||
|
for (uint32_t i = 0; i < depth + 1; ++i) {
|
||||||
|
result += " ";
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||||
|
static std::string const kLeafTemplate =
|
||||||
|
R"L({ "nodeid": {nid}, "leaf": {leaf} {stat}})L";
|
||||||
|
static std::string const kStatTemplate =
|
||||||
|
R"S(, "cover": {sum_hess} )S";
|
||||||
|
std::string result = SuperT::Match(
|
||||||
|
kLeafTemplate,
|
||||||
|
{{"{nid}", std::to_string(nid)},
|
||||||
|
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())},
|
||||||
|
{"{stat}", with_stats_ ? SuperT::Match(
|
||||||
|
kStatTemplate,
|
||||||
|
{{"{sum_hess}",
|
||||||
|
SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Indicator(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||||
|
int32_t nyes = tree[nid].DefaultLeft() ?
|
||||||
|
tree[nid].RightChild() : tree[nid].LeftChild();
|
||||||
|
static std::string const kIndicatorTemplate =
|
||||||
|
R"ID( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", "yes": {yes}, "no": {no}})ID";
|
||||||
|
auto split_index = tree[nid].SplitIndex();
|
||||||
|
auto result = SuperT::Match(
|
||||||
|
kIndicatorTemplate,
|
||||||
|
{{"{nid}", std::to_string(nid)},
|
||||||
|
{"{depth}", std::to_string(depth)},
|
||||||
|
{"{fname}", fmap_.Name(split_index)},
|
||||||
|
{"{yes}", std::to_string(nyes)},
|
||||||
|
{"{no}", std::to_string(tree[nid].DefaultChild())}});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SplitNodeImpl(RegTree const& tree, int32_t nid,
|
||||||
|
std::string const& template_str, std::string cond, uint32_t depth) {
|
||||||
|
auto split_index = tree[nid].SplitIndex();
|
||||||
|
std::string const result = SuperT::Match(
|
||||||
|
template_str,
|
||||||
|
{{"{nid}", std::to_string(nid)},
|
||||||
|
{"{depth}", std::to_string(depth)},
|
||||||
|
{"{fname}", split_index < fmap_.Size() ? fmap_.Name(split_index) :
|
||||||
|
std::to_string(split_index)},
|
||||||
|
{"{cond}", cond},
|
||||||
|
{"{left}", std::to_string(tree[nid].LeftChild())},
|
||||||
|
{"{right}", std::to_string(tree[nid].RightChild())},
|
||||||
|
{"{missing}", std::to_string(tree[nid].DefaultChild())}});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Integer(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||||
|
auto cond = tree[nid].SplitCond();
|
||||||
|
const bst_float floored = std::floor(cond);
|
||||||
|
const int32_t integer_threshold
|
||||||
|
= (floored == cond) ? static_cast<int32_t>(floored)
|
||||||
|
: static_cast<int32_t>(floored) + 1;
|
||||||
|
static std::string const kIntegerTemplate =
|
||||||
|
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
|
||||||
|
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
|
||||||
|
R"I("missing": {missing})I";
|
||||||
|
return SplitNodeImpl(tree, nid, kIntegerTemplate,
|
||||||
|
std::to_string(integer_threshold), depth);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string Quantitive(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||||
|
static std::string const kQuantitiveTemplate =
|
||||||
|
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
|
||||||
|
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
|
||||||
|
R"I("missing": {missing})I";
|
||||||
|
bst_float cond = tree[nid].SplitCond();
|
||||||
|
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||||
|
auto cond = tree[nid].SplitCond();
|
||||||
|
static std::string const kNodeTemplate =
|
||||||
|
R"I( "nodeid": {nid}, "depth": {depth}, "split": {fname}, )I"
|
||||||
|
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
|
||||||
|
R"I("missing": {missing})I";
|
||||||
|
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string NodeStat(RegTree const& tree, int32_t nid) override {
|
||||||
|
static std::string kStatTemplate =
|
||||||
|
R"S(, "gain": {loss_chg}, "cover": {sum_hess})S";
|
||||||
|
auto result = SuperT::Match(
|
||||||
|
kStatTemplate,
|
||||||
|
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)},
|
||||||
|
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string SplitNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||||
|
std::string properties = SuperT::SplitNode(tree, nid, depth);
|
||||||
|
static std::string const kSplitNodeTemplate =
|
||||||
|
"{{properties} {stat}, \"children\": [{left}, {right}\n{indent}]}";
|
||||||
|
auto result = SuperT::Match(
|
||||||
|
kSplitNodeTemplate,
|
||||||
|
{{"{properties}", properties},
|
||||||
|
{"{stat}", with_stats_ ? this->NodeStat(tree, nid) : ""},
|
||||||
|
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
|
||||||
|
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)},
|
||||||
|
{"{indent}", this->Indent(depth)}});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||||
|
static std::string const kNodeTemplate = "{newline}{indent}{nodes}";
|
||||||
|
auto result = SuperT::Match(
|
||||||
|
kNodeTemplate,
|
||||||
|
{{"{newline}", depth == 0 ? "" : "\n"},
|
||||||
|
{"{indent}", Indent(depth)},
|
||||||
|
{"{nodes}", tree[nid].IsLeaf() ? this->LeafNode(tree, nid, depth) :
|
||||||
|
this->SplitNode(tree, nid, depth)}});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
XGBOOST_REGISTER_TREE_IO(JsonGenerator, "json")
|
||||||
|
.describe("Dump json representation of tree")
|
||||||
|
.set_body([](FeatureMap const& fmap, std::string const& attrs, bool with_stats) {
|
||||||
|
return new JsonGenerator(fmap, attrs, with_stats);
|
||||||
|
});
|
||||||
|
|
||||||
|
struct GraphvizParam : public dmlc::Parameter<GraphvizParam> {
|
||||||
|
std::string yes_color;
|
||||||
|
std::string no_color;
|
||||||
|
std::string rankdir;
|
||||||
|
std::string condition_node_params;
|
||||||
|
std::string leaf_node_params;
|
||||||
|
std::string graph_attrs;
|
||||||
|
|
||||||
|
DMLC_DECLARE_PARAMETER(GraphvizParam){
|
||||||
|
DMLC_DECLARE_FIELD(yes_color)
|
||||||
|
.set_default("#0000FF")
|
||||||
|
.describe("Edge color when meets the node condition.");
|
||||||
|
DMLC_DECLARE_FIELD(no_color)
|
||||||
|
.set_default("#FF0000")
|
||||||
|
.describe("Edge color when doesn't meet the node condition.");
|
||||||
|
DMLC_DECLARE_FIELD(rankdir)
|
||||||
|
.set_default("TB")
|
||||||
|
.describe("Passed to graphiz via graph_attr.");
|
||||||
|
DMLC_DECLARE_FIELD(condition_node_params)
|
||||||
|
.set_default("")
|
||||||
|
.describe("Conditional node configuration");
|
||||||
|
DMLC_DECLARE_FIELD(leaf_node_params)
|
||||||
|
.set_default("")
|
||||||
|
.describe("Leaf node configuration");
|
||||||
|
DMLC_DECLARE_FIELD(graph_attrs)
|
||||||
|
.set_default("")
|
||||||
|
.describe("Any other extra attributes for graphviz `graph_attr`.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
DMLC_REGISTER_PARAMETER(GraphvizParam);
|
||||||
|
|
||||||
|
class GraphvizGenerator : public TreeGenerator {
|
||||||
|
using SuperT = TreeGenerator;
|
||||||
|
std::stringstream& ss_;
|
||||||
|
GraphvizParam param_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
GraphvizGenerator(FeatureMap const& fmap, std::string const& attrs, bool with_stats) :
|
||||||
|
TreeGenerator(fmap, with_stats), ss_{SuperT::ss_} {
|
||||||
|
param_.InitAllowUnknown(std::map<std::string, std::string>{});
|
||||||
|
using KwArg = std::map<std::string, std::map<std::string, std::string>>;
|
||||||
|
KwArg kwargs;
|
||||||
|
if (attrs.length() != 0) {
|
||||||
|
std::istringstream iss(attrs);
|
||||||
|
try {
|
||||||
|
dmlc::JSONReader reader(&iss);
|
||||||
|
reader.Read(&kwargs);
|
||||||
|
} catch(dmlc::Error const& e) {
|
||||||
|
LOG(FATAL) << "Failed to parse graphviz parameters:\n\t"
|
||||||
|
<< attrs << "\n"
|
||||||
|
<< "With error:\n"
|
||||||
|
<< e.what();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (with_stats) {
|
// This turns out to be tricky, as `dmlc::Parameter::Load(JSONReader*)` doesn't
|
||||||
if (format == "json") {
|
// support loading nested json objects.
|
||||||
fo << ", \"gain\": " << std::setprecision(float_max_precision) << tree.Stat(nid).loss_chg
|
if (kwargs.find("condition_node_params") != kwargs.cend()) {
|
||||||
<< ", \"cover\": " << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess;
|
auto const& cnp = kwargs["condition_node_params"];
|
||||||
} else {
|
for (auto const& kv : cnp) {
|
||||||
fo << ",gain=" << std::setprecision(float_max_precision) << tree.Stat(nid).loss_chg
|
param_.condition_node_params += kv.first + '=' + "\"" + kv.second + "\" ";
|
||||||
<< ",cover=" << std::setprecision(float_max_precision) << tree.Stat(nid).sum_hess;
|
|
||||||
}
|
}
|
||||||
|
kwargs.erase("condition_node_params");
|
||||||
}
|
}
|
||||||
if (format == "json") {
|
if (kwargs.find("leaf_node_params") != kwargs.cend()) {
|
||||||
fo << ", \"children\": [";
|
auto const& lnp = kwargs["leaf_node_params"];
|
||||||
} else {
|
for (auto const& kv : lnp) {
|
||||||
fo << '\n';
|
param_.leaf_node_params += kv.first + '=' + "\"" + kv.second + "\" ";
|
||||||
}
|
|
||||||
DumpRegTree(fo, tree, fmap, tree[nid].LeftChild(), depth + 1, false, with_stats, format);
|
|
||||||
DumpRegTree(fo, tree, fmap, tree[nid].RightChild(), depth + 1, true, with_stats, format);
|
|
||||||
if (format == "json") {
|
|
||||||
fo << std::endl;
|
|
||||||
for (int i = 0; i < depth + 1; ++i) {
|
|
||||||
fo << " ";
|
|
||||||
}
|
}
|
||||||
fo << "]}";
|
kwargs.erase("leaf_node_params");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (kwargs.find("edge") != kwargs.cend()) {
|
||||||
|
if (kwargs["edge"].find("yes_color") != kwargs["edge"].cend()) {
|
||||||
|
param_.yes_color = kwargs["edge"]["yes_color"];
|
||||||
|
}
|
||||||
|
if (kwargs["edge"].find("no_color") != kwargs["edge"].cend()) {
|
||||||
|
param_.no_color = kwargs["edge"]["no_color"];
|
||||||
|
}
|
||||||
|
kwargs.erase("edge");
|
||||||
|
}
|
||||||
|
auto const& extra = kwargs["graph_attrs"];
|
||||||
|
static std::string const kGraphTemplate = " graph [ {key}=\"{value}\" ]\n";
|
||||||
|
for (auto const& kv : extra) {
|
||||||
|
param_.graph_attrs += SuperT::Match(kGraphTemplate,
|
||||||
|
{{"{key}", kv.first},
|
||||||
|
{"{value}", kv.second}});
|
||||||
|
}
|
||||||
|
|
||||||
|
kwargs.erase("graph_attrs");
|
||||||
|
if (kwargs.size() != 0) {
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << "The following parameters for graphviz are not recognized:\n";
|
||||||
|
for (auto kv : kwargs) {
|
||||||
|
ss << kv.first << ", ";
|
||||||
|
}
|
||||||
|
LOG(WARNING) << ss.str();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
protected:
|
||||||
|
// Only indicator is different, so we combine all different node types into this
|
||||||
|
// function.
|
||||||
|
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||||
|
auto split = tree[nid].SplitIndex();
|
||||||
|
auto cond = tree[nid].SplitCond();
|
||||||
|
static std::string const kNodeTemplate =
|
||||||
|
" {nid} [ label=\"{fname}{<}{cond}\" {params}]\n";
|
||||||
|
|
||||||
|
// Indicator only has fname.
|
||||||
|
bool has_less = (split >= fmap_.Size()) || fmap_.type(split) != FeatureMap::kIndicator;
|
||||||
|
std::string result = SuperT::Match(kNodeTemplate, {
|
||||||
|
{"{nid}", std::to_string(nid)},
|
||||||
|
{"{fname}", split < fmap_.Size() ? fmap_.Name(split) :
|
||||||
|
'f' + std::to_string(split)},
|
||||||
|
{"{<}", has_less ? "<" : ""},
|
||||||
|
{"{cond}", has_less ? SuperT::ToStr(cond) : ""},
|
||||||
|
{"{params}", param_.condition_node_params}});
|
||||||
|
|
||||||
|
static std::string const kEdgeTemplate =
|
||||||
|
" {nid} -> {child} [label=\"{is_missing}\" color=\"{color}\"]\n";
|
||||||
|
auto MatchFn = SuperT::Match; // mingw failed to capture protected fn.
|
||||||
|
auto BuildEdge =
|
||||||
|
[&tree, nid, MatchFn, this](int32_t child) {
|
||||||
|
bool is_missing = tree[nid].DefaultChild() == child;
|
||||||
|
std::string buffer = MatchFn(kEdgeTemplate, {
|
||||||
|
{"{nid}", std::to_string(nid)},
|
||||||
|
{"{child}", std::to_string(child)},
|
||||||
|
{"{color}", is_missing ? param_.yes_color : param_.no_color},
|
||||||
|
{"{is_missing}", is_missing ? "yes, missing": "no"}});
|
||||||
|
return buffer;
|
||||||
|
};
|
||||||
|
result += BuildEdge(tree[nid].LeftChild());
|
||||||
|
result += BuildEdge(tree[nid].RightChild());
|
||||||
|
return result;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||||
|
static std::string const kLeafTemplate =
|
||||||
|
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
|
||||||
|
auto result = SuperT::Match(kLeafTemplate, {
|
||||||
|
{"{nid}", std::to_string(nid)},
|
||||||
|
{"{leaf-value}", ToStr(tree[nid].LeafValue())},
|
||||||
|
{"{params}", param_.leaf_node_params}});
|
||||||
|
return result;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override {
|
||||||
|
if (tree[nid].IsLeaf()) {
|
||||||
|
return this->LeafNode(tree, nid, depth);
|
||||||
|
}
|
||||||
|
static std::string const kNodeTemplate = "{parent}\n{left}\n{right}";
|
||||||
|
auto result = SuperT::Match(
|
||||||
|
kNodeTemplate,
|
||||||
|
{{"{parent}", this->PlainNode(tree, nid, depth)},
|
||||||
|
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
|
||||||
|
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void BuildTree(RegTree const& tree) override {
|
||||||
|
static std::string const kTreeTemplate =
|
||||||
|
"digraph {\n"
|
||||||
|
" graph [ rankdir={rankdir} ]\n"
|
||||||
|
"{graph_attrs}\n"
|
||||||
|
"{nodes}}";
|
||||||
|
auto result = SuperT::Match(
|
||||||
|
kTreeTemplate,
|
||||||
|
{{"{rankdir}", param_.rankdir},
|
||||||
|
{"{graph_attrs}", param_.graph_attrs},
|
||||||
|
{"{nodes}", this->BuildTree(tree, 0, 0)}});
|
||||||
|
ss_ << result;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot")
|
||||||
|
.describe("Dump graphviz representation of tree")
|
||||||
|
.set_body([](FeatureMap const& fmap, std::string attrs, bool with_stats) {
|
||||||
|
return new GraphvizGenerator(fmap, attrs, with_stats);
|
||||||
|
});
|
||||||
|
|
||||||
std::string RegTree::DumpModel(const FeatureMap& fmap,
|
std::string RegTree::DumpModel(const FeatureMap& fmap,
|
||||||
bool with_stats,
|
bool with_stats,
|
||||||
std::string format) const {
|
std::string format) const {
|
||||||
std::stringstream fo("");
|
std::unique_ptr<TreeGenerator> builder {
|
||||||
for (int i = 0; i < param.num_roots; ++i) {
|
TreeGenerator::Create(format, fmap, with_stats)
|
||||||
DumpRegTree(fo, *this, fmap, i, 0, false, with_stats, format);
|
};
|
||||||
|
for (int32_t i = 0; i < param.num_roots; ++i) {
|
||||||
|
builder->BuildTree(*this);
|
||||||
}
|
}
|
||||||
return fo.str();
|
|
||||||
|
std::string result = builder->Str();
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
void RegTree::FillNodeMeanValues() {
|
void RegTree::FillNodeMeanValues() {
|
||||||
size_t num_nodes = this->param.num_nodes;
|
size_t num_nodes = this->param.num_nodes;
|
||||||
if (this->node_mean_values_.size() == num_nodes) {
|
if (this->node_mean_values_.size() == num_nodes) {
|
||||||
|
|||||||
@ -144,7 +144,7 @@ __global__ void CubScanByKeyL1(
|
|||||||
int previousKey = __shfl_up_sync(0xFFFFFFFF, myKey, 1);
|
int previousKey = __shfl_up_sync(0xFFFFFFFF, myKey, 1);
|
||||||
#else
|
#else
|
||||||
int previousKey = __shfl_up(myKey, 1);
|
int previousKey = __shfl_up(myKey, 1);
|
||||||
#endif
|
#endif // (__CUDACC_VER_MAJOR__ >= 9)
|
||||||
// Collectively compute the block-wide exclusive prefix sum
|
// Collectively compute the block-wide exclusive prefix sum
|
||||||
BlockScan(temp_storage)
|
BlockScan(temp_storage)
|
||||||
.ExclusiveScan(threadData, threadData, rootPair, AddByKey());
|
.ExclusiveScan(threadData, threadData, rootPair, AddByKey());
|
||||||
|
|||||||
@ -101,4 +101,121 @@ TEST(Tree, AllocateNode) {
|
|||||||
ASSERT_TRUE(nodes.at(1).IsLeaf());
|
ASSERT_TRUE(nodes.at(1).IsLeaf());
|
||||||
ASSERT_TRUE(nodes.at(2).IsLeaf());
|
ASSERT_TRUE(nodes.at(2).IsLeaf());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
RegTree ConstructTree() {
|
||||||
|
RegTree tree;
|
||||||
|
tree.ExpandNode(
|
||||||
|
/*nid=*/0, /*split_index=*/0, /*split_value=*/0.0f,
|
||||||
|
/*default_left=*/true,
|
||||||
|
0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||||
|
auto left = tree[0].LeftChild();
|
||||||
|
auto right = tree[0].RightChild();
|
||||||
|
tree.ExpandNode(
|
||||||
|
/*nid=*/left, /*split_index=*/1, /*split_value=*/1.0f,
|
||||||
|
/*default_left=*/false,
|
||||||
|
0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||||
|
tree.ExpandNode(
|
||||||
|
/*nid=*/right, /*split_index=*/2, /*split_value=*/2.0f,
|
||||||
|
/*default_left=*/false,
|
||||||
|
0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||||
|
return tree;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Tree, DumpJson) {
|
||||||
|
auto tree = ConstructTree();
|
||||||
|
FeatureMap fmap;
|
||||||
|
auto str = tree.DumpModel(fmap, true, "json");
|
||||||
|
size_t n_leaves = 0;
|
||||||
|
size_t iter = 0;
|
||||||
|
while ((iter = str.find("leaf", iter + 1)) != std::string::npos) {
|
||||||
|
n_leaves++;
|
||||||
|
}
|
||||||
|
ASSERT_EQ(n_leaves, 4);
|
||||||
|
|
||||||
|
size_t n_conditions = 0;
|
||||||
|
iter = 0;
|
||||||
|
while ((iter = str.find("split_condition", iter + 1)) != std::string::npos) {
|
||||||
|
n_conditions++;
|
||||||
|
}
|
||||||
|
ASSERT_EQ(n_conditions, 3);
|
||||||
|
|
||||||
|
fmap.PushBack(0, "feat_0", "i");
|
||||||
|
fmap.PushBack(1, "feat_1", "q");
|
||||||
|
fmap.PushBack(2, "feat_2", "int");
|
||||||
|
|
||||||
|
str = tree.DumpModel(fmap, true, "json");
|
||||||
|
ASSERT_NE(str.find(R"("split": "feat_0")"), std::string::npos);
|
||||||
|
ASSERT_NE(str.find(R"("split": "feat_1")"), std::string::npos);
|
||||||
|
ASSERT_NE(str.find(R"("split": "feat_2")"), std::string::npos);
|
||||||
|
|
||||||
|
str = tree.DumpModel(fmap, false, "json");
|
||||||
|
ASSERT_EQ(str.find("cover"), std::string::npos);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Tree, DumpText) {
|
||||||
|
auto tree = ConstructTree();
|
||||||
|
FeatureMap fmap;
|
||||||
|
auto str = tree.DumpModel(fmap, true, "text");
|
||||||
|
size_t n_leaves = 0;
|
||||||
|
size_t iter = 0;
|
||||||
|
while ((iter = str.find("leaf", iter + 1)) != std::string::npos) {
|
||||||
|
n_leaves++;
|
||||||
|
}
|
||||||
|
ASSERT_EQ(n_leaves, 4);
|
||||||
|
|
||||||
|
iter = 0;
|
||||||
|
size_t n_conditions = 0;
|
||||||
|
while ((iter = str.find("gain", iter + 1)) != std::string::npos) {
|
||||||
|
n_conditions++;
|
||||||
|
}
|
||||||
|
ASSERT_EQ(n_conditions, 3);
|
||||||
|
|
||||||
|
ASSERT_NE(str.find("[f0<0]"), std::string::npos);
|
||||||
|
ASSERT_NE(str.find("[f1<1]"), std::string::npos);
|
||||||
|
ASSERT_NE(str.find("[f2<2]"), std::string::npos);
|
||||||
|
|
||||||
|
fmap.PushBack(0, "feat_0", "i");
|
||||||
|
fmap.PushBack(1, "feat_1", "q");
|
||||||
|
fmap.PushBack(2, "feat_2", "int");
|
||||||
|
|
||||||
|
str = tree.DumpModel(fmap, true, "text");
|
||||||
|
ASSERT_NE(str.find("[feat_0]"), std::string::npos);
|
||||||
|
ASSERT_NE(str.find("[feat_1<1]"), std::string::npos);
|
||||||
|
ASSERT_NE(str.find("[feat_2<2]"), std::string::npos);
|
||||||
|
|
||||||
|
str = tree.DumpModel(fmap, false, "text");
|
||||||
|
ASSERT_EQ(str.find("cover"), std::string::npos);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(Tree, DumpDot) {
|
||||||
|
auto tree = ConstructTree();
|
||||||
|
FeatureMap fmap;
|
||||||
|
auto str = tree.DumpModel(fmap, true, "dot");
|
||||||
|
|
||||||
|
size_t n_leaves = 0;
|
||||||
|
size_t iter = 0;
|
||||||
|
while ((iter = str.find("leaf", iter + 1)) != std::string::npos) {
|
||||||
|
n_leaves++;
|
||||||
|
}
|
||||||
|
ASSERT_EQ(n_leaves, 4);
|
||||||
|
|
||||||
|
size_t n_edges = 0;
|
||||||
|
iter = 0;
|
||||||
|
while ((iter = str.find("->", iter + 1)) != std::string::npos) {
|
||||||
|
n_edges++;
|
||||||
|
}
|
||||||
|
ASSERT_EQ(n_edges, 6);
|
||||||
|
|
||||||
|
fmap.PushBack(0, "feat_0", "i");
|
||||||
|
fmap.PushBack(1, "feat_1", "q");
|
||||||
|
fmap.PushBack(2, "feat_2", "int");
|
||||||
|
|
||||||
|
str = tree.DumpModel(fmap, true, "dot");
|
||||||
|
ASSERT_NE(str.find(R"("feat_0")"), std::string::npos);
|
||||||
|
ASSERT_NE(str.find(R"(feat_1<1)"), std::string::npos);
|
||||||
|
ASSERT_NE(str.find(R"(feat_2<2)"), std::string::npos);
|
||||||
|
|
||||||
|
str = tree.DumpModel(fmap, true, R"(dot:{"graph_attrs": {"bgcolor": "#FFFF00"}})");
|
||||||
|
ASSERT_NE(str.find(R"(graph [ bgcolor="#FFFF00" ])"), std::string::npos);
|
||||||
|
}
|
||||||
} // namespace xgboost
|
} // namespace xgboost
|
||||||
|
|||||||
@ -10,7 +10,7 @@ try:
|
|||||||
import matplotlib
|
import matplotlib
|
||||||
matplotlib.use('Agg')
|
matplotlib.use('Agg')
|
||||||
from matplotlib.axes import Axes
|
from matplotlib.axes import Axes
|
||||||
from graphviz import Digraph
|
from graphviz import Source
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -57,7 +57,7 @@ class TestPlotting(unittest.TestCase):
|
|||||||
assert ax.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # blue
|
assert ax.patches[3].get_facecolor() == (0, 0, 1.0, 1.0) # blue
|
||||||
|
|
||||||
g = xgb.to_graphviz(bst2, num_trees=0)
|
g = xgb.to_graphviz(bst2, num_trees=0)
|
||||||
assert isinstance(g, Digraph)
|
assert isinstance(g, Source)
|
||||||
|
|
||||||
ax = xgb.plot_tree(bst2, num_trees=0)
|
ax = xgb.plot_tree(bst2, num_trees=0)
|
||||||
assert isinstance(ax, Axes)
|
assert isinstance(ax, Axes)
|
||||||
|
|||||||
@ -87,7 +87,6 @@ class TestSHAP(unittest.TestCase):
|
|||||||
r_exp = r"([0-9]+):\[f([0-9]+)<([0-9\.e-]+)\] yes=([0-9]+),no=([0-9]+).*cover=([0-9e\.]+)"
|
r_exp = r"([0-9]+):\[f([0-9]+)<([0-9\.e-]+)\] yes=([0-9]+),no=([0-9]+).*cover=([0-9e\.]+)"
|
||||||
r_exp_leaf = r"([0-9]+):leaf=([0-9\.e-]+),cover=([0-9e\.]+)"
|
r_exp_leaf = r"([0-9]+):leaf=([0-9\.e-]+),cover=([0-9e\.]+)"
|
||||||
for tree in model.get_dump(with_stats=True):
|
for tree in model.get_dump(with_stats=True):
|
||||||
|
|
||||||
lines = list(tree.splitlines())
|
lines = list(tree.splitlines())
|
||||||
trees.append([None for i in range(len(lines))])
|
trees.append([None for i in range(len(lines))])
|
||||||
for line in lines:
|
for line in lines:
|
||||||
|
|||||||
@ -352,7 +352,7 @@ def test_sklearn_plotting():
|
|||||||
matplotlib.use('Agg')
|
matplotlib.use('Agg')
|
||||||
|
|
||||||
from matplotlib.axes import Axes
|
from matplotlib.axes import Axes
|
||||||
from graphviz import Digraph
|
from graphviz import Source
|
||||||
|
|
||||||
ax = xgb.plot_importance(classifier)
|
ax = xgb.plot_importance(classifier)
|
||||||
assert isinstance(ax, Axes)
|
assert isinstance(ax, Axes)
|
||||||
@ -362,7 +362,7 @@ def test_sklearn_plotting():
|
|||||||
assert len(ax.patches) == 4
|
assert len(ax.patches) == 4
|
||||||
|
|
||||||
g = xgb.to_graphviz(classifier, num_trees=0)
|
g = xgb.to_graphviz(classifier, num_trees=0)
|
||||||
assert isinstance(g, Digraph)
|
assert isinstance(g, Source)
|
||||||
|
|
||||||
ax = xgb.plot_tree(classifier, num_trees=0)
|
ax = xgb.plot_tree(classifier, num_trees=0)
|
||||||
assert isinstance(ax, Axes)
|
assert isinstance(ax, Axes)
|
||||||
@ -641,7 +641,8 @@ def test_XGBClassifier_resume():
|
|||||||
|
|
||||||
X, Y = load_breast_cancer(return_X_y=True)
|
X, Y = load_breast_cancer(return_X_y=True)
|
||||||
|
|
||||||
model1 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8)
|
model1 = xgb.XGBClassifier(
|
||||||
|
learning_rate=0.3, random_state=0, n_estimators=8)
|
||||||
model1.fit(X, Y)
|
model1.fit(X, Y)
|
||||||
|
|
||||||
pred1 = model1.predict(X)
|
pred1 = model1.predict(X)
|
||||||
@ -649,7 +650,8 @@ def test_XGBClassifier_resume():
|
|||||||
|
|
||||||
# file name of stored xgb model
|
# file name of stored xgb model
|
||||||
model1.save_model(model1_path)
|
model1.save_model(model1_path)
|
||||||
model2 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8)
|
model2 = xgb.XGBClassifier(
|
||||||
|
learning_rate=0.3, random_state=0, n_estimators=8)
|
||||||
model2.fit(X, Y, xgb_model=model1_path)
|
model2.fit(X, Y, xgb_model=model1_path)
|
||||||
|
|
||||||
pred2 = model2.predict(X)
|
pred2 = model2.predict(X)
|
||||||
@ -660,7 +662,8 @@ def test_XGBClassifier_resume():
|
|||||||
|
|
||||||
# file name of 'Booster' instance Xgb model
|
# file name of 'Booster' instance Xgb model
|
||||||
model1.get_booster().save_model(model1_booster_path)
|
model1.get_booster().save_model(model1_booster_path)
|
||||||
model2 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8)
|
model2 = xgb.XGBClassifier(
|
||||||
|
learning_rate=0.3, random_state=0, n_estimators=8)
|
||||||
model2.fit(X, Y, xgb_model=model1_booster_path)
|
model2.fit(X, Y, xgb_model=model1_booster_path)
|
||||||
|
|
||||||
pred2 = model2.predict(X)
|
pred2 = model2.predict(X)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user