xgboost/demo/guide-python/predict_first_ntree.py
Jiaming Yuan 1f9a57d17b
[Breaking] Require format to be specified in input URI. (#9077)
Previously, we used `libsvm` as the default when the format was not specified. However, the dmlc
data parser is not particularly robust against errors, and the most common type of error
is an undefined format.

Along with this change, we recommend that users use other data loaders instead. We will
continue to maintain the parsers, as they are currently used for many internal tests,
including federated learning.
2023-04-28 19:45:15 +08:00

61 lines
1.9 KiB
Python

"""
Demo for prediction using number of trees
=========================================
"""
import os
import numpy as np
from sklearn.datasets import load_svmlight_file
import xgboost as xgb
# Directory containing this script; the demo data files are located relative to it.
CURRENT_DIR = os.path.dirname(__file__)
# Paths to the agaricus training and test splits used by both demos below.
train, test = (
    os.path.join(CURRENT_DIR, relative_path)
    for relative_path in (
        "../data/agaricus.txt.train",
        "../data/agaricus.txt.test",
    )
)
def native_interface():
    """Train a small booster with the native API and compare test errors.

    Loads the train/test files as ``DMatrix`` objects, trains for three
    rounds, then prints the test-set classification error twice: once when
    predicting with only the first tree and once with every trained tree.
    """

    def error_rate(scores, truth):
        # Fraction of samples whose thresholded (> 0.5) score disagrees
        # with the ground-truth label.
        return np.sum((scores > 0.5) != truth) / float(len(truth))

    # Load the data and run training; the "?format=libsvm" suffix states the
    # file format explicitly in the input URI.
    train_matrix = xgb.DMatrix(train + "?format=libsvm")
    test_matrix = xgb.DMatrix(test + "?format=libsvm")
    booster = xgb.train(
        {"max_depth": 2, "eta": 1, "objective": "binary:logistic"},
        train_matrix,
        3,
        [(test_matrix, "eval"), (train_matrix, "train")],
    )

    print("start testing prediction from first n trees")
    truth = test_matrix.get_label()
    # Restrict prediction to the first tree only.
    first_tree_scores = booster.predict(test_matrix, iteration_range=(0, 1))
    # With no iteration_range given, all trees are used.
    all_tree_scores = booster.predict(test_matrix)
    print("error of ypred1=%f" % error_rate(first_tree_scores, truth))
    print("error of ypred2=%f" % error_rate(all_tree_scores, truth))
def sklearn_interface():
    """Train with the scikit-learn wrapper and compare test errors.

    Loads the train/test files with ``load_svmlight_file``, fits a
    three-estimator ``XGBClassifier`` while evaluating on the test split,
    then prints the test-set classification error twice: once when
    predicting with only the first tree and once with every trained tree.
    """
    X_train, y_train = load_svmlight_file(train)
    # load_svmlight_file infers the number of columns from each file
    # separately, so a test split that never uses the highest-indexed feature
    # would produce a narrower matrix than the training one. Pin the test
    # matrix to the training width so both always agree.
    X_test, y_test = load_svmlight_file(test, n_features=X_train.shape[1])

    clf = xgb.XGBClassifier(n_estimators=3, max_depth=2, eta=1)
    clf.fit(X_train, y_train, eval_set=[(X_test, y_test)])
    # Sanity check: the demo data is expected to be a binary problem.
    assert clf.n_classes_ == 2

    print("start testing prediction from first n trees")
    n_test = float(len(y_test))
    # Restrict prediction to the first tree only.
    ypred1 = clf.predict(X_test, iteration_range=(0, 1))
    # With no iteration_range given, all trees are used.
    ypred2 = clf.predict(X_test)
    # NOTE(review): the 0.5 threshold mirrors the native demo; presumably
    # predict() already yields class labels here, making it a no-op — confirm.
    print("error of ypred1=%f" % (np.sum((ypred1 > 0.5) != y_test) / n_test))
    print("error of ypred2=%f" % (np.sum((ypred2 > 0.5) != y_test) / n_test))
if __name__ == "__main__":
    # Run the demos in order: the native API first, then the sklearn wrapper.
    for run_demo in (native_interface, sklearn_interface):
        run_demo()