Initial support for multioutput regression. (#7514)
* Add num target model parameter, which is configured from input labels. * Change elementwise metric and indexing for weights. * Add demo. * Add tests.
This commit is contained in:
@@ -12,9 +12,9 @@
|
||||
#include "xgboost/json.h"
|
||||
#include "../../src/common/io.h"
|
||||
#include "../../src/common/random.h"
|
||||
#include "../../src/common/linalg_op.h"
|
||||
|
||||
namespace xgboost {
|
||||
|
||||
TEST(Learner, Basic) {
|
||||
using Arg = std::pair<std::string, std::string>;
|
||||
auto args = {Arg("tree_method", "exact")};
|
||||
@@ -278,6 +278,7 @@ TEST(Learner, GPUConfiguration) {
|
||||
labels[i] = i;
|
||||
}
|
||||
p_dmat->Info().labels.Data()->HostVector() = labels;
|
||||
p_dmat->Info().labels.Reshape(kRows);
|
||||
{
|
||||
std::unique_ptr<Learner> learner {Learner::Create(mat)};
|
||||
learner->SetParams({Arg{"booster", "gblinear"},
|
||||
@@ -424,4 +425,28 @@ TEST(Learner, FeatureInfo) {
|
||||
ASSERT_TRUE(std::equal(out_types.begin(), out_types.end(), types.begin()));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(Learner, MultiTarget) {
|
||||
size_t constexpr kRows{128}, kCols{10}, kTargets{3};
|
||||
auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
|
||||
m->Info().labels.Reshape(kRows, kTargets);
|
||||
linalg::ElementWiseKernelHost(m->Info().labels.HostView(), omp_get_max_threads(),
|
||||
[](auto i, auto) { return i; });
|
||||
|
||||
{
|
||||
std::unique_ptr<Learner> learner{Learner::Create({m})};
|
||||
learner->Configure();
|
||||
|
||||
Json model{Object()};
|
||||
learner->SaveModel(&model);
|
||||
ASSERT_EQ(get<String>(model["learner"]["learner_model_param"]["num_target"]),
|
||||
std::to_string(kTargets));
|
||||
}
|
||||
{
|
||||
std::unique_ptr<Learner> learner{Learner::Create({m})};
|
||||
learner->SetParam("objective", "multi:softprob");
|
||||
// unsupported objective.
|
||||
EXPECT_THROW({ learner->Configure(); }, dmlc::Error);
|
||||
}
|
||||
}
|
||||
} // namespace xgboost
|
||||
|
||||
Reference in New Issue
Block a user