Handle missing values in one-hot splits. (#7934)
parent 18a38f7ca0
commit 606be9e663
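Previously the hist evaluator had no way to route missing values through a one-hot categorical split; this commit adds a dedicated `EnumerateOneHot` scan that evaluates both default directions for the missing rows and keeps the better one. A minimal usage sketch, not part of the diff, assuming an XGBoost build that contains this change (the data and parameter values are made up for illustration):

```python
import numpy as np
import pandas as pd
import xgboost as xgb

rng = np.random.RandomState(1994)
n = 256

# One categorical feature with roughly half of the entries missing.
c = pd.Series(rng.randint(0, 8, size=n)).astype("category")
c[rng.rand(n) < 0.5] = np.nan
X = pd.DataFrame({"c": c})
y = rng.randn(n)

Xy = xgb.DMatrix(X, y, enable_categorical=True)
# A huge max_cat_to_onehot forces one-hot splits for every categorical
# feature; missing values are routed by the learned default direction.
booster = xgb.train(
    {"tree_method": "hist", "max_cat_to_onehot": np.iinfo(np.int32).max},
    Xy,
    num_boost_round=8,
)
```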
@@ -45,14 +45,72 @@ class HistEvaluator {
   // then - there are no missing values
   // else - there are missing values
   bool static SplitContainsMissingValues(const GradStats e, const NodeEntry &snode) {
-    if (e.GetGrad() == snode.stats.GetGrad() &&
-        e.GetHess() == snode.stats.GetHess()) {
+    if (e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess()) {
       return false;
     } else {
       return true;
     }
   }
 
+  bool IsValid(GradStats const &left, GradStats const &right) const {
+    return left.GetHess() >= param_.min_child_weight && right.GetHess() >= param_.min_child_weight;
+  }
+
+  /**
+   * \brief Use learned direction with one-hot split. Other implementations (LGB, sklearn)
+   *        create a pseudo-category for missing value but here we just do a complete scan
+   *        to avoid making specialized histogram bin.
+   */
+  void EnumerateOneHot(common::HistogramCuts const &cut, const common::GHistRow &hist,
+                       bst_feature_t fidx, bst_node_t nidx,
+                       TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
+                       SplitEntry *p_best) const {
+    const std::vector<uint32_t> &cut_ptr = cut.Ptrs();
+    const std::vector<bst_float> &cut_val = cut.Values();
+
+    bst_bin_t ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
+    bst_bin_t iend = static_cast<bst_bin_t>(cut_ptr[fidx + 1]);
+    bst_bin_t n_bins = iend - ibegin;
+
+    GradStats left_sum;
+    GradStats right_sum;
+    // best split so far
+    SplitEntry best;
+
+    auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
+    auto feature_sum = GradStats{
+        std::accumulate(f_hist.data(), f_hist.data() + f_hist.size(), GradientPairPrecise{})};
+    GradStats missing;
+    auto const &parent = snode_[nidx];
+    missing.SetSubstract(parent.stats, feature_sum);
+
+    for (bst_bin_t i = ibegin; i != iend; i += 1) {
+      auto split_pt = cut_val[i];
+
+      // missing on left (treat missing as other categories)
+      right_sum = GradStats{hist[i]};
+      left_sum.SetSubstract(parent.stats, right_sum);
+      if (IsValid(left_sum, right_sum)) {
+        auto missing_left_chg = static_cast<float>(
+            evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
+            parent.root_gain);
+        best.Update(missing_left_chg, fidx, split_pt, true, true, left_sum, right_sum);
+      }
+
+      // missing on right (treat missing as chosen category)
+      left_sum.SetSubstract(left_sum, missing);
+      right_sum.Add(missing);
+      if (IsValid(left_sum, right_sum)) {
+        auto missing_right_chg = static_cast<float>(
+            evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
+            parent.root_gain);
+        best.Update(missing_right_chg, fidx, split_pt, false, true, left_sum, right_sum);
+      }
+    }
+
+    p_best->Update(best);
+  }
+
   // Enumerate/Scan the split values of specific feature
   // Returns the sum of gradients corresponding to the data points that contains
   // a non-missing value for the particular feature fid.
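The doc comment above is the heart of the change: instead of materializing a pseudo-category for missing values, `EnumerateOneHot` recovers the missing-value statistics by subtracting the feature's histogram total from the parent total, then evaluates each category twice, once with the missing rows on the left and once on the right. A hedged NumPy sketch of that loop, using the plain squared-gradient-over-Hessian gain with only `reg_lambda` (the real `CalcSplitGain` also applies monotonic and interaction constraints, so this is an illustration, not the implementation):

```python
import numpy as np

def enumerate_one_hot(g_bins, h_bins, g_parent, h_parent,
                      reg_lambda=1.0, min_child_weight=1.0):
    # Gradient/Hessian sums of the rows that are *missing* for this feature:
    # parent total minus the per-category histogram total.
    g_miss = g_parent - g_bins.sum()
    h_miss = h_parent - h_bins.sum()

    def gain(g, h):
        # Plain second-order split gain for one child.
        return g * g / (h + reg_lambda)

    parent_gain = gain(g_parent, h_parent)
    best = (-np.inf, None, None)  # (gain change, category, default_left)

    for k in range(len(g_bins)):
        # Missing on left: category k goes right, everything else
        # (including the missing rows) goes left.
        g_r, h_r = g_bins[k], h_bins[k]
        g_l, h_l = g_parent - g_r, h_parent - h_r
        if min(h_l, h_r) >= min_child_weight:
            chg = gain(g_l, h_l) + gain(g_r, h_r) - parent_gain
            best = max(best, (chg, k, True))

        # Missing on right: move the missing statistics over to the
        # chosen category's side.
        g_l, h_l = g_l - g_miss, h_l - h_miss
        g_r, h_r = g_r + g_miss, h_r + h_miss
        if min(h_l, h_r) >= min_child_weight:
            chg = gain(g_l, h_l) + gain(g_r, h_r) - parent_gain
            best = max(best, (chg, k, False))

    return best

# Toy histogram over three categories; the parent totals imply missing
# mass of g=-2, h=2.
g = np.array([1.0, -3.0, 0.5])
h = np.array([2.0, 2.0, 1.0])
print(enumerate_one_hot(g, h, g_parent=g.sum() - 2.0, h_parent=h.sum() + 2.0))
```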
@@ -102,9 +160,7 @@ class HistEvaluator {
           break;
         }
         case kOneHot: {
-          // not-chosen categories go to left
-          right_sum = GradStats{hist[i]};
-          left_sum.SetSubstract(parent.stats, right_sum);
+          std::terminate();  // unreachable
           break;
         }
         case kPart: {
@@ -151,7 +207,7 @@ class HistEvaluator {
           break;
         }
         case kOneHot: {
-          split_pt = cut_val[i];
+          std::terminate();  // unreachable
           break;
         }
         case kPart: {
@@ -188,7 +244,6 @@ class HistEvaluator {
         // Normal, accumulated to left
         return left_sum;
       case kOneHot:
-        // Doesn't matter, not accumulating.
         return {};
       case kPart:
         // Accumulated to right due to chosen cats go to right.
@@ -242,8 +297,7 @@ class HistEvaluator {
     if (is_cat) {
       auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
       if (common::UseOneHot(n_bins, param_.max_cat_to_onehot)) {
-        EnumerateSplit<+1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
-        EnumerateSplit<-1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
+        EnumerateOneHot(cut, histogram, fidx, nidx, evaluator, best);
       } else {
         std::vector<size_t> sorted_idx(n_bins);
         std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
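This dispatch, together with the `USE_ONEHOT`/`USE_PART` constants introduced in the tests below, hinges on `common::UseOneHot`. A sketch of the assumed rule follows; the exact comparison is an assumption, not quoted from this diff:

```python
import numpy as np

def use_onehot(n_cats: int, max_cat_to_onehot: int) -> bool:
    # Assumed behaviour of common::UseOneHot: one-hot splits for
    # low-cardinality features, partition-based splits otherwise.
    return n_cats < max_cat_to_onehot

# The test constants below exercise both branches:
assert use_onehot(8, np.iinfo(np.int32).max)  # USE_ONEHOT: always one-hot
assert not use_onehot(8, 1)                   # USE_PART: never one-hot
```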
@@ -214,6 +214,9 @@ class TestTreeMethod:
         self.run_max_cat(tree_method)
 
     def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
+        USE_ONEHOT = np.iinfo(np.int32).max
+        USE_PART = 1
+
         onehot, label = tm.make_categorical(rows, cols, cats, True)
         cat, _ = tm.make_categorical(rows, cols, cats, False)
 
@@ -221,10 +224,9 @@ class TestTreeMethod:
         by_builtin_results = {}
 
         predictor = "gpu_predictor" if tree_method == "gpu_hist" else None
+        parameters = {"tree_method": tree_method, "predictor": predictor}
         # Use one-hot exclusively
-        parameters = {
-            "tree_method": tree_method, "predictor": predictor, "max_cat_to_onehot": 9999
-        }
+        parameters["max_cat_to_onehot"] = USE_ONEHOT
 
         m = xgb.DMatrix(onehot, label, enable_categorical=False)
         xgb.train(
@@ -257,7 +259,8 @@ class TestTreeMethod:
         assert tm.non_increasing(by_builtin_results["Train"]["rmse"])
 
         by_grouping: xgb.callback.TrainingCallback.EvalsLog = {}
-        parameters["max_cat_to_onehot"] = 1
+        # switch to partition-based splits
+        parameters["max_cat_to_onehot"] = USE_PART
         parameters["reg_lambda"] = 0
         m = xgb.DMatrix(cat, label, enable_categorical=True)
         xgb.train(
@@ -284,6 +287,27 @@ class TestTreeMethod:
         )
         assert tm.non_increasing(by_grouping["Train"]["rmse"]), by_grouping
 
+        # test with missing values
+        cat, label = tm.make_categorical(
+            n_samples=256, n_features=4, n_categories=8, onehot=False, sparsity=0.5
+        )
+        Xy = xgb.DMatrix(cat, label, enable_categorical=True)
+        evals_result = {}
+        # Test with onehot splits
+        parameters["max_cat_to_onehot"] = USE_ONEHOT
+        booster = xgb.train(
+            parameters,
+            Xy,
+            num_boost_round=16,
+            evals=[(Xy, "Train")],
+            evals_result=evals_result
+        )
+        assert tm.non_increasing(evals_result["Train"]["rmse"])
+        y_predt = booster.predict(Xy)
+
+        rmse = tm.root_mean_square(label, y_predt)
+        np.testing.assert_allclose(rmse, evals_result["Train"]["rmse"][-1])
+
     @given(strategies.integers(10, 400), strategies.integers(3, 8),
            strategies.integers(1, 2), strategies.integers(4, 7))
     @settings(deadline=None, print_blob=True)
@@ -302,7 +302,7 @@ def get_mq2008(dpath):
 
 @memory.cache
 def make_categorical(
-    n_samples: int, n_features: int, n_categories: int, onehot: bool
+    n_samples: int, n_features: int, n_categories: int, onehot: bool, sparsity=0.0,
 ):
     import pandas as pd
 
@@ -325,6 +325,13 @@ def make_categorical(
     for col in df.columns:
         df[col] = df[col].cat.set_categories(categories)
 
+    if sparsity > 0.0:
+        for i in range(n_features):
+            index = rng.randint(low=0, high=n_samples-1, size=int(n_samples * sparsity))
+            df.iloc[index, i] = np.NaN
+            assert df.iloc[:, i].isnull().values.any()
+            assert n_categories == np.unique(df.dtypes[i].categories).size
+
     if onehot:
         return pd.get_dummies(df), label
     return df, label
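For reference, the sparsity injection in isolation as a small self-contained sketch (the column data is made up). One detail worth knowing: `np.random.randint` excludes `high`, so `high=n_samples-1` can never pick the last row; that is harmless here because the helper only asserts that *some* values per column are missing:

```python
import numpy as np
import pandas as pd

rng = np.random.RandomState(1994)
n_samples, n_categories, sparsity = 64, 4, 0.5

col = pd.Series(rng.randint(0, n_categories, size=n_samples)).astype("category")
col = col.cat.set_categories(list(range(n_categories)))

# Overwrite a random subset of rows with NaN, as the helper does per column.
index = rng.randint(low=0, high=n_samples - 1, size=int(n_samples * sparsity))
col[index] = np.nan

assert col.isnull().values.any()
assert n_categories == np.unique(col.dtype.categories).size
```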
@@ -538,6 +545,12 @@ def eval_error_metric_skl(y_true: np.ndarray, y_score: np.ndarray) -> float:
     return np.sum(r)
 
 
+def root_mean_square(y_true: np.ndarray, y_score: np.ndarray) -> float:
+    err = y_score - y_true
+    rmse = np.sqrt(np.dot(err, err) / y_score.size)
+    return rmse
+
+
 def softmax(x):
     e = np.exp(x)
     return e / np.sum(e)
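The new `root_mean_square` helper computes RMSE through a dot product rather than a mean of squares; a quick worked check of the formula (the values are made up):

```python
import numpy as np

y_true = np.array([1.0, 2.0, 3.0])
y_score = np.array([1.0, 2.5, 2.0])

err = y_score - y_true                       # [0.0, 0.5, -1.0]
rmse = np.sqrt(np.dot(err, err) / err.size)  # sqrt((0.0 + 0.25 + 1.0) / 3)
np.testing.assert_allclose(rmse, np.sqrt(1.25 / 3.0))
```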