xgboost/src/tree/constraints.h

64 lines
1.9 KiB
C++

/**
* Copyright 2018-2023 by Contributors
*/
#ifndef XGBOOST_TREE_CONSTRAINTS_H_
#define XGBOOST_TREE_CONSTRAINTS_H_
#include <string>
#include <unordered_set>
#include <vector>
#include "param.h"
#include "xgboost/base.h"
namespace xgboost {
/*!
 * \brief Feature interaction constraint implementation for CPU tree updaters.
 *
 * The interface is similar to the one for GPU Hist. While constraints are
 * disabled (`enabled_` is false, the default), Split() is a no-op and Query()
 * accepts every feature.
 */
class FeatureInteractionConstraintHost {
 protected:
  // interaction_constraints_[constraint_id] contains a single interaction
  // constraint, which specifies a group of feature IDs that can interact
  // with each other
  std::vector<std::unordered_set<bst_feature_t>> interaction_constraints_;
  // node_constraints_[nid] contains the set of all feature IDs that are
  // allowed to be used for a split at node nid
  std::vector<std::unordered_set<bst_feature_t>> node_constraints_;
  // splits_[nid] contains the set of all feature IDs that have been used for
  // splits in node nid and its parents
  std::vector<std::unordered_set<bst_feature_t>> splits_;
  // string passed by user.
  std::string interaction_constraint_str_;
  // number of features in DMatrix/Booster. Brace-initialised so that a
  // default-constructed object never holds an indeterminate value.
  bst_feature_t n_features_{0};
  // Whether constraints are active. NOTE(review): presumably set by
  // Configure(), which is defined out of line — confirm.
  bool enabled_{false};

  // Out-of-line implementation of Split(); only invoked when enabled_ is true.
  void SplitImpl(int32_t node_id, bst_feature_t feature_id, bst_node_t left_id,
                 bst_node_t right_id);

 public:
  FeatureInteractionConstraintHost() = default;

  /*!
   * \brief Record that node `node_id` was split on `feature_id` into the
   *        children `left_id` and `right_id`.
   *
   * Does nothing when constraints are disabled; otherwise forwards to
   * SplitImpl().
   */
  void Split(int32_t node_id, bst_feature_t feature_id, bst_node_t left_id,
             bst_node_t right_id) {
    if (!enabled_) {
      return;  // nothing to track without interaction constraints
    }
    this->SplitImpl(node_id, feature_id, left_id, right_id);
  }

  /*!
   * \brief Whether feature `fid` may be used to split node `nid`.
   *
   * \return true when constraints are disabled; otherwise whether `fid` is in
   *         node_constraints_[nid].
   * \throws std::out_of_range (via vector::at) when constraints are enabled
   *         and `nid` has no entry in node_constraints_.
   */
  bool Query(bst_node_t nid, bst_feature_t fid) const {
    if (!enabled_) {
      return true;
    }
    // Look the node up once (bounds-checked) and reuse the reference.
    auto const& allowed = node_constraints_.at(nid);
    return allowed.find(fid) != allowed.cend();
  }

  // Defined out of line.
  void Reset();
  // Defined out of line; receives the training parameters and feature count.
  void Configure(tree::TrainParam const& param, bst_feature_t const n_features);
};
} // namespace xgboost
#endif // XGBOOST_TREE_CONSTRAINTS_H_