Implement GPU predict leaf. (#6187)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import sys
|
||||
import json
|
||||
import unittest
|
||||
import pytest
|
||||
|
||||
@@ -9,6 +10,7 @@ from hypothesis import given, strategies, assume, settings, note
|
||||
sys.path.append("tests/python")
|
||||
import testing as tm
|
||||
from test_predict import run_threaded_predict # noqa
|
||||
from test_predict import run_predict_leaf # noqa
|
||||
|
||||
rng = np.random.RandomState(1994)
|
||||
|
||||
@@ -18,6 +20,11 @@ shap_parameter_strategy = strategies.fixed_dictionaries({
|
||||
'num_parallel_tree': strategies.sampled_from([1, 10]),
|
||||
}).filter(lambda x: x['max_depth'] > 0 or x['max_leaves'] > 0)
|
||||
|
||||
predict_parameter_strategy = strategies.fixed_dictionaries({
|
||||
'max_depth': strategies.integers(1, 8),
|
||||
'num_parallel_tree': strategies.sampled_from([1, 4]),
|
||||
})
|
||||
|
||||
|
||||
class TestGPUPredict(unittest.TestCase):
|
||||
def test_predict(self):
|
||||
@@ -223,3 +230,34 @@ class TestGPUPredict(unittest.TestCase):
|
||||
assert np.allclose(np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)),
|
||||
margin,
|
||||
1e-3, 1e-3)
|
||||
|
||||
def test_predict_leaf_basic(self):
|
||||
gpu_leaf = run_predict_leaf('gpu_predictor')
|
||||
cpu_leaf = run_predict_leaf('cpu_predictor')
|
||||
np.testing.assert_equal(gpu_leaf, cpu_leaf)
|
||||
|
||||
def run_predict_leaf_booster(self, param, num_rounds, dataset):
|
||||
param = dataset.set_params(param)
|
||||
m = dataset.get_dmat()
|
||||
booster = xgb.train(param, dtrain=dataset.get_dmat(), num_boost_round=num_rounds)
|
||||
booster.set_param({'predictor': 'cpu_predictor'})
|
||||
cpu_leaf = booster.predict(m, pred_leaf=True)
|
||||
|
||||
booster.set_param({'predictor': 'gpu_predictor'})
|
||||
gpu_leaf = booster.predict(m, pred_leaf=True)
|
||||
|
||||
np.testing.assert_equal(cpu_leaf, gpu_leaf)
|
||||
|
||||
@given(predict_parameter_strategy, tm.dataset_strategy)
|
||||
@settings(deadline=None)
|
||||
def test_predict_leaf_gbtree(self, param, dataset):
|
||||
param['booster'] = 'gbtree'
|
||||
param['tree_method'] = 'gpu_hist'
|
||||
self.run_predict_leaf_booster(param, 10, dataset)
|
||||
|
||||
@given(predict_parameter_strategy, tm.dataset_strategy)
|
||||
@settings(deadline=None)
|
||||
def test_predict_leaf_dart(self, param, dataset):
|
||||
param['booster'] = 'dart'
|
||||
param['tree_method'] = 'gpu_hist'
|
||||
self.run_predict_leaf_booster(param, 10, dataset)
|
||||
|
||||
Reference in New Issue
Block a user