Fix GPU L1 error. (#8749)
This commit is contained in:
@@ -337,13 +337,21 @@ class TestGPUPredict:
|
||||
@given(predict_parameter_strategy, tm.dataset_strategy)
|
||||
@settings(deadline=None, max_examples=20, print_blob=True)
|
||||
def test_predict_leaf_gbtree(self, param, dataset):
|
||||
# Unsupported for random forest
|
||||
if param.get("num_parallel_tree", 1) > 1 and dataset.name.endswith("-l1"):
|
||||
return
|
||||
|
||||
param['booster'] = 'gbtree'
|
||||
param['tree_method'] = 'gpu_hist'
|
||||
self.run_predict_leaf_booster(param, 10, dataset)
|
||||
|
||||
@given(predict_parameter_strategy, tm.dataset_strategy)
|
||||
@settings(deadline=None, max_examples=20, print_blob=True)
|
||||
def test_predict_leaf_dart(self, param, dataset):
|
||||
def test_predict_leaf_dart(self, param: dict, dataset: tm.TestDataset) -> None:
|
||||
# Unsupported for random forest
|
||||
if param.get("num_parallel_tree", 1) > 1 and dataset.name.endswith("-l1"):
|
||||
return
|
||||
|
||||
param['booster'] = 'dart'
|
||||
param['tree_method'] = 'gpu_hist'
|
||||
self.run_predict_leaf_booster(param, 10, dataset)
|
||||
|
||||
Reference in New Issue
Block a user