GPUTreeShap (#6038)
This commit is contained in:
@@ -131,6 +131,7 @@ class TestDataset:
|
||||
self.metric = metric
|
||||
self.X, self.y = get_dataset()
|
||||
self.w = None
|
||||
self.margin = None
|
||||
|
||||
def set_params(self, params_in):
|
||||
params_in['objective'] = self.objective
|
||||
@@ -140,13 +141,13 @@ class TestDataset:
|
||||
return params_in
|
||||
|
||||
def get_dmat(self):
|
||||
return xgb.DMatrix(self.X, self.y, self.w)
|
||||
return xgb.DMatrix(self.X, self.y, self.w, base_margin=self.margin)
|
||||
|
||||
def get_device_dmat(self):
|
||||
w = None if self.w is None else cp.array(self.w)
|
||||
X = cp.array(self.X, dtype=np.float32)
|
||||
y = cp.array(self.y, dtype=np.float32)
|
||||
return xgb.DeviceQuantileDMatrix(X, y, w)
|
||||
return xgb.DeviceQuantileDMatrix(X, y, w, base_margin=self.margin)
|
||||
|
||||
def get_external_dmat(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
@@ -157,7 +158,7 @@ class TestDataset:
|
||||
uri = path + '?format=csv&label_column=0#tmptmp_'
|
||||
# The uri looks like:
|
||||
# 'tmptmp_1234.csv?format=csv&label_column=0#tmptmp_'
|
||||
return xgb.DMatrix(uri, weight=self.w)
|
||||
return xgb.DMatrix(uri, weight=self.w, base_margin=self.margin)
|
||||
|
||||
def __repr__(self):
|
||||
return self.name
|
||||
@@ -206,16 +207,23 @@ _unweighted_datasets_strategy = strategies.sampled_from(
|
||||
|
||||
|
||||
@strategies.composite
|
||||
def _dataset_and_weight(draw):
|
||||
def _dataset_weight_margin(draw):
|
||||
data = draw(_unweighted_datasets_strategy)
|
||||
if draw(strategies.booleans()):
|
||||
data.w = draw(arrays(np.float64, (len(data.y)), elements=strategies.floats(0.1, 2.0)))
|
||||
if draw(strategies.booleans()):
|
||||
num_class = 1
|
||||
if data.objective == "multi:softmax":
|
||||
num_class = int(np.max(data.y) + 1)
|
||||
data.margin = draw(
|
||||
arrays(np.float64, (len(data.y) * num_class), elements=strategies.floats(0.5, 1.0)))
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# A strategy for drawing from a set of example datasets
|
||||
# May add random weights to the dataset
|
||||
dataset_strategy = _dataset_and_weight()
|
||||
dataset_strategy = _dataset_weight_margin()
|
||||
|
||||
|
||||
def non_increasing(L, tolerance=1e-4):
|
||||
|
||||
Reference in New Issue
Block a user