Previously, we used `libsvm` as the default when the format was not specified. However, the dmlc data parser is not particularly robust against errors, and the most common type of error is an undefined format. Along with this change, we recommend that users use other data loaders instead. We will continue to maintain the parsers, as they are currently used for many internal tests, including federated learning.
32 lines
850 B
Python
32 lines
850 B
Python
"""
Demo for obtaining leaf index
=============================
"""
import os

import xgboost as xgb

# Load the data and do the training.  The "?format=libsvm" suffix on each
# path names the text format explicitly rather than leaving it unspecified.
CURRENT_DIR = os.path.dirname(__file__)
dtrain = xgb.DMatrix(
    os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
)
dtest = xgb.DMatrix(
    os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
)
param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
# Data sets whose metrics are reported during training, with display names.
watchlist = [(dtest, "eval"), (dtrain, "train")]
num_round = 3
# The evaluation sets are passed by keyword: `evals` is accepted by name on
# every release, while passing it positionally is deprecated on newer ones.
bst = xgb.train(param, dtrain, num_round, evals=watchlist)

print("start testing predict the leaf indices")
# Predict using only the first 2 trees; pred_leaf=True asks for leaf indices
# instead of prediction scores.
leafindex = bst.predict(
    dtest, iteration_range=(0, 2), pred_leaf=True, strict_shape=True
)
print(leafindex.shape)
print(leafindex)
# Predict using all trees.
leafindex = bst.predict(dtest, pred_leaf=True)
print(leafindex.shape)