Reduce base margin to 2 dim for now. (#7455)

This commit is contained in:
Jiaming Yuan
2021-11-27 00:46:13 +08:00
committed by GitHub
parent bf7bb575b4
commit 557ffc4bf5
7 changed files with 33 additions and 33 deletions

View File

@@ -61,7 +61,8 @@ Predictor* Predictor::Create(
return p_predictor;
}
void ValidateBaseMarginShape(linalg::Tensor<float, 3> const& margin, bst_row_t n_samples,
template <int32_t D>
void ValidateBaseMarginShape(linalg::Tensor<float, D> const& margin, bst_row_t n_samples,
bst_group_t n_groups) {
// FIXME: Bindings other than Python doesn't have shape.
std::string expected{"Invalid shape of base_margin. Expected: (" + std::to_string(n_samples) +