Use matrix for gradient. (#9508)

- Use the `linalg::Matrix` for storing gradients.
- New API for the custom objective.
- Custom objective for multi-class/multi-target is now required to return the correct shape.
- Custom objective for Python can accept arrays with any strides. (row-major, column-major)
This commit is contained in:
Jiaming Yuan
2023-08-24 05:29:52 +08:00
committed by GitHub
parent 6103dca0bb
commit 972730cde0
77 changed files with 1052 additions and 651 deletions

View File

@@ -202,13 +202,13 @@ TEST(QuantileHist, PartitionerColSplit) { TestColumnSplitPartitioner<CPUExpandEn
TEST(QuantileHist, MultiPartitionerColSplit) { TestColumnSplitPartitioner<MultiExpandEntry>(3); }
namespace {
void VerifyColumnSplit(bst_row_t rows, bst_feature_t cols, bst_target_t n_targets,
void VerifyColumnSplit(Context const* ctx, bst_row_t rows, bst_feature_t cols, bst_target_t n_targets,
RegTree const& expected_tree) {
auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true);
auto p_gradients = GenerateGradients(rows, n_targets);
Context ctx;
linalg::Matrix<GradientPair> gpair = GenerateRandomGradients(ctx, rows, n_targets);
ObjInfo task{ObjInfo::kRegression};
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_quantile_histmaker", &ctx, &task)};
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_quantile_histmaker", ctx, &task)};
std::vector<HostDeviceVector<bst_node_t>> position(1);
std::unique_ptr<DMatrix> sliced{Xy->SliceCol(collective::GetWorldSize(), collective::GetRank())};
@@ -217,7 +217,7 @@ void VerifyColumnSplit(bst_row_t rows, bst_feature_t cols, bst_target_t n_target
TrainParam param;
param.Init(Args{});
updater->Configure(Args{});
updater->Update(&param, p_gradients.get(), sliced.get(), position, {&tree});
updater->Update(&param, &gpair, sliced.get(), position, {&tree});
Json json{Object{}};
tree.SaveModel(&json);
@@ -232,21 +232,21 @@ void TestColumnSplit(bst_target_t n_targets) {
RegTree expected_tree{n_targets, kCols};
ObjInfo task{ObjInfo::kRegression};
Context ctx;
{
auto Xy = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
auto p_gradients = GenerateGradients(kRows, n_targets);
Context ctx;
auto gpair = GenerateRandomGradients(&ctx, kRows, n_targets);
std::unique_ptr<TreeUpdater> updater{
TreeUpdater::Create("grow_quantile_histmaker", &ctx, &task)};
std::vector<HostDeviceVector<bst_node_t>> position(1);
TrainParam param;
param.Init(Args{});
updater->Configure(Args{});
updater->Update(&param, p_gradients.get(), Xy.get(), position, {&expected_tree});
updater->Update(&param, &gpair, Xy.get(), position, {&expected_tree});
}
auto constexpr kWorldSize = 2;
RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit, kRows, kCols, n_targets,
RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit, &ctx, kRows, kCols, n_targets,
std::cref(expected_tree));
}
} // anonymous namespace