Implement tree model dump with code generator. (#4602)
* Implement tree model dump with a code generator. * Split up generators. * Implement graphviz generator. * Use pattern matching. * [Breaking] Return a Source in `to_graphviz` instead of Digraph in Python package. Co-Authored-By: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
This commit is contained in:
@@ -352,7 +352,7 @@ def test_sklearn_plotting():
|
||||
matplotlib.use('Agg')
|
||||
|
||||
from matplotlib.axes import Axes
|
||||
from graphviz import Digraph
|
||||
from graphviz import Source
|
||||
|
||||
ax = xgb.plot_importance(classifier)
|
||||
assert isinstance(ax, Axes)
|
||||
@@ -362,7 +362,7 @@ def test_sklearn_plotting():
|
||||
assert len(ax.patches) == 4
|
||||
|
||||
g = xgb.to_graphviz(classifier, num_trees=0)
|
||||
assert isinstance(g, Digraph)
|
||||
assert isinstance(g, Source)
|
||||
|
||||
ax = xgb.plot_tree(classifier, num_trees=0)
|
||||
assert isinstance(ax, Axes)
|
||||
@@ -641,7 +641,8 @@ def test_XGBClassifier_resume():
|
||||
|
||||
X, Y = load_breast_cancer(return_X_y=True)
|
||||
|
||||
model1 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8)
|
||||
model1 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=8)
|
||||
model1.fit(X, Y)
|
||||
|
||||
pred1 = model1.predict(X)
|
||||
@@ -649,7 +650,8 @@ def test_XGBClassifier_resume():
|
||||
|
||||
# file name of stored xgb model
|
||||
model1.save_model(model1_path)
|
||||
model2 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8)
|
||||
model2 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=8)
|
||||
model2.fit(X, Y, xgb_model=model1_path)
|
||||
|
||||
pred2 = model2.predict(X)
|
||||
@@ -660,7 +662,8 @@ def test_XGBClassifier_resume():
|
||||
|
||||
# file name of 'Booster' instance Xgb model
|
||||
model1.get_booster().save_model(model1_booster_path)
|
||||
model2 = xgb.XGBClassifier(learning_rate=0.3, seed=0, n_estimators=8)
|
||||
model2 = xgb.XGBClassifier(
|
||||
learning_rate=0.3, random_state=0, n_estimators=8)
|
||||
model2.fit(X, Y, xgb_model=model1_booster_path)
|
||||
|
||||
pred2 = model2.predict(X)
|
||||
|
||||
Reference in New Issue
Block a user