xgboost/src/data/validation.h
Jiaming Yuan 55ee272ea8
Extend array interface to handle ndarray. (#7434)
* Extend array interface to handle ndarray.

The `ArrayInterface` class is extended to support multi-dimensional array inputs. Previously, this
class handled only 2-dim arrays (a vector was treated as a matrix). This PR specifies the expected
dimension at compile time, so the array interface can perform various checks on the input data
automatically. Also, adapters like CSR are more rigorous about validating their input. Lastly, row
vectors and column vectors are handled without intervention from the caller.
2021-11-16 09:52:15 +08:00

41 lines
1020 B
C++

/*!
* Copyright 2021 by XGBoost Contributors
*/
#ifndef XGBOOST_DATA_VALIDATION_H_
#define XGBOOST_DATA_VALIDATION_H_
#include <algorithm>
#include <cmath>
#include <vector>

#include "xgboost/base.h"
#include "xgboost/logging.h"
namespace xgboost {
namespace data {
/*!
 * \brief Stateless predicate that reports whether a label value is not finite.
 *
 * \param y The label value to inspect.
 * \return true if `y` is NaN or +/-infinity, false otherwise.
 *
 * The call operator is const-qualified: the functor carries no state, so it must be
 * callable through const objects/references (e.g. when an algorithm stores the
 * predicate as const).
 */
struct LabelsCheck {
  XGBOOST_DEVICE bool operator()(float y) const {
#if defined(__CUDA_ARCH__)
    // Device compilation pass: use the global-namespace math overloads.
    return ::isnan(y) || ::isinf(y);
#else
    return std::isnan(y) || std::isinf(y);
#endif
  }
};
/*!
 * \brief Stateless predicate that reports whether a weight value is unusable.
 *
 * \param w The weight value to inspect.
 * \return true if `w` is NaN, +/-infinity (as judged by LabelsCheck), or strictly
 *         negative; a weight of exactly 0 is accepted.
 *
 * The call operator is const-qualified because the functor has no state.
 */
struct WeightsCheck {
  XGBOOST_DEVICE bool operator()(float w) const { return LabelsCheck{}(w) || w < 0; }  // NOLINT
};
inline void ValidateQueryGroup(std::vector<bst_group_t> const &group_ptr_) {
bool valid_query_group = true;
for (size_t i = 1; i < group_ptr_.size(); ++i) {
valid_query_group = valid_query_group && group_ptr_[i] >= group_ptr_[i - 1];
if (XGBOOST_EXPECT(!valid_query_group, false)) {
break;
}
}
CHECK(valid_query_group) << "Invalid group structure.";
}
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_VALIDATION_H_