Use matrix for gradient. (#9508)
- Use the `linalg::Matrix` for storing gradients. - New API for the custom objective. - Custom objective for multi-class/multi-target is now required to return the correct shape. - Custom objective for Python can accept arrays with any strides. (row-major, column-major)
This commit is contained in:
@@ -202,13 +202,13 @@ TEST(QuantileHist, PartitionerColSplit) { TestColumnSplitPartitioner<CPUExpandEn
|
||||
TEST(QuantileHist, MultiPartitionerColSplit) { TestColumnSplitPartitioner<MultiExpandEntry>(3); }
|
||||
|
||||
namespace {
|
||||
void VerifyColumnSplit(bst_row_t rows, bst_feature_t cols, bst_target_t n_targets,
|
||||
void VerifyColumnSplit(Context const* ctx, bst_row_t rows, bst_feature_t cols, bst_target_t n_targets,
|
||||
RegTree const& expected_tree) {
|
||||
auto Xy = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true);
|
||||
auto p_gradients = GenerateGradients(rows, n_targets);
|
||||
Context ctx;
|
||||
linalg::Matrix<GradientPair> gpair = GenerateRandomGradients(ctx, rows, n_targets);
|
||||
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_quantile_histmaker", &ctx, &task)};
|
||||
std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_quantile_histmaker", ctx, &task)};
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
|
||||
std::unique_ptr<DMatrix> sliced{Xy->SliceCol(collective::GetWorldSize(), collective::GetRank())};
|
||||
@@ -217,7 +217,7 @@ void VerifyColumnSplit(bst_row_t rows, bst_feature_t cols, bst_target_t n_target
|
||||
TrainParam param;
|
||||
param.Init(Args{});
|
||||
updater->Configure(Args{});
|
||||
updater->Update(¶m, p_gradients.get(), sliced.get(), position, {&tree});
|
||||
updater->Update(¶m, &gpair, sliced.get(), position, {&tree});
|
||||
|
||||
Json json{Object{}};
|
||||
tree.SaveModel(&json);
|
||||
@@ -232,21 +232,21 @@ void TestColumnSplit(bst_target_t n_targets) {
|
||||
|
||||
RegTree expected_tree{n_targets, kCols};
|
||||
ObjInfo task{ObjInfo::kRegression};
|
||||
Context ctx;
|
||||
{
|
||||
auto Xy = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true);
|
||||
auto p_gradients = GenerateGradients(kRows, n_targets);
|
||||
Context ctx;
|
||||
auto gpair = GenerateRandomGradients(&ctx, kRows, n_targets);
|
||||
std::unique_ptr<TreeUpdater> updater{
|
||||
TreeUpdater::Create("grow_quantile_histmaker", &ctx, &task)};
|
||||
std::vector<HostDeviceVector<bst_node_t>> position(1);
|
||||
TrainParam param;
|
||||
param.Init(Args{});
|
||||
updater->Configure(Args{});
|
||||
updater->Update(¶m, p_gradients.get(), Xy.get(), position, {&expected_tree});
|
||||
updater->Update(¶m, &gpair, Xy.get(), position, {&expected_tree});
|
||||
}
|
||||
|
||||
auto constexpr kWorldSize = 2;
|
||||
RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit, kRows, kCols, n_targets,
|
||||
RunWithInMemoryCommunicator(kWorldSize, VerifyColumnSplit, &ctx, kRows, kCols, n_targets,
|
||||
std::cref(expected_tree));
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
Reference in New Issue
Block a user